Diffstat (limited to 'sql/core'): 85 files changed, 1402 insertions, 1325 deletions
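The diffs below all make the same migration: the Java test suites drop the explicit SparkContext/SQLContext pair in favor of a single SparkSession (or TestSparkSession), and the Scala suites replace sqlContext call sites with the session. A minimal sketch of the new JUnit setUp/tearDown shape these suites converge on (the class name is hypothetical, not part of the patch):

```java
import java.io.Serializable;

import org.junit.After;
import org.junit.Before;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public class ExampleSparkSessionSuite implements Serializable {
  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // One SparkSession replaces the separate SparkContext + SQLContext pair.
    spark = SparkSession.builder()
      .master("local[*]")
      .appName("testing")
      .getOrCreate();
    // A JavaSparkContext is still derived from the session where a test
    // needs to parallelize Java collections into RDDs.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop();
    spark = null;
  }
}
```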
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 189cc3972c..f2ae40e644 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -28,14 +28,13 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType; // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { - private transient JavaSparkContext javaCtx; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - SparkContext context = new SparkContext("local[*]", "testing"); - javaCtx = new JavaSparkContext(context); - sqlContext = new SQLContext(context); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - javaCtx = null; + spark.stop(); + spark = null; } public static class Person implements Serializable { @@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable { person2.setAge(28); personList.add(person2); - JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( + JavaRDD<Row> rowRDD = jsc.parallelize(personList).map( new Function<Person, Row>() { @Override public Row call(Person person) throws Exception { @@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = spark.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList(); + List<Row> actual = spark.sql("SELECT * FROM people").collectAsList(); List<Row> expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); @@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable { person2.setAge(28); personList.add(person2); - JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map( + JavaRDD<Row> rowRDD = jsc.parallelize(personList).map( new Function<Person, Row>() { @Override public Row call(Person person) { @@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = spark.createDataFrame(rowRDD, schema); 
df.registerTempTable("people"); - List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD() + List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD() .map(new Function<Row, String>() { @Override public String call(Row row) { @@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable { @Test public void applySchemaToJSON() { - JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList( + JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList( "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + "\"boolean\":true, \"null\":null}", @@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable { null, "this is another simple string.")); - Dataset<Row> df1 = sqlContext.read().json(jsonRDD); + Dataset<Row> df1 = spark.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); - List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); + List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); - List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); + List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1eb680dc4c..324ebbae38 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,12 +20,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; import java.net.URISyntaxException; import java.net.URL; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.ArrayList; +import java.util.*; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import org.junit.*; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.BloomFilter; import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; -import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new 
SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } @Test public void testExecution() { - Dataset<Row> df = context.table("testData").filter("key = 1"); + Dataset<Row> df = spark.table("testData").filter("key = 1"); Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); } @Test public void testCollectAndTake() { - Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3"); Assert.assertEquals(3, df.select("key").collectAsList().size()); Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } @@ -83,7 +77,7 @@ public class JavaDataFrameSuite { */ @Test public void testVarargMethods() { - Dataset<Row> df = context.table("testData"); + Dataset<Row> df = spark.table("testData"); df.toDF("key1", "value1"); @@ -112,7 +106,7 @@ public class JavaDataFrameSuite { df.select(coalesce(col("key"))); // Varargs with mathfunctions - Dataset<Row> df2 = context.table("testData2"); + Dataset<Row> df2 = spark.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -126,7 +120,7 @@ public class JavaDataFrameSuite { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - Dataset<Row> df = context.table("testData"); + Dataset<Row> df = spark.table("testData"); df.show(); df.show(1000); } @@ -194,7 +188,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List<Bean> data = Arrays.asList(bean); - Dataset<Row> df = context.createDataFrame(data, Bean.class); + Dataset<Row> df = spark.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -202,7 +196,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean)); - Dataset<Row> df = context.createDataFrame(rdd, Bean.class); + Dataset<Row> df = spark.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -210,7 +204,7 @@ public class JavaDataFrameSuite { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List<Row> rows = Arrays.asList(RowFactory.create(0)); - Dataset<Row> df = context.createDataFrame(rows, schema); + Dataset<Row> df = spark.createDataFrame(rows, schema); List<Row> result = df.collectAsList(); Assert.assertEquals(1, result.size()); } @@ -239,7 +233,7 @@ public class JavaDataFrameSuite { @Test public void testCrosstab() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Dataset<Row> crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); @@ -258,7 +252,7 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); String[] cols = {"a"}; Dataset<Row> results = df.stat().freqItems(cols, 0.2); 
Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); @@ -266,21 +260,21 @@ public class JavaDataFrameSuite { @Test public void testCorrelation() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); Assert.assertEquals(0, actual.get(0).getLong(0)); @@ -291,7 +285,7 @@ public class JavaDataFrameSuite { @Test public void pivot() { - Dataset<Row> df = context.table("courseSales"); + Dataset<Row> df = spark.table("courseSales"); List<Row> actual = df.groupBy("year") .pivot("course", Arrays.<Object>asList("dotNET", "Java")) .agg(sum("earnings")).orderBy("year").collectAsList(); @@ -324,10 +318,10 @@ public class JavaDataFrameSuite { @Test public void testGenericLoad() { - Dataset<Row> df1 = context.read().format("text").load(getResource("text-suite.txt")); + Dataset<Row> df1 = spark.read().format("text").load(getResource("text-suite.txt")); Assert.assertEquals(4L, df1.count()); - Dataset<Row> df2 = context.read().format("text").load( + Dataset<Row> df2 = spark.read().format("text").load( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, df2.count()); @@ -335,10 +329,10 @@ public class JavaDataFrameSuite { @Test public void testTextLoad() { - Dataset<String> ds1 = context.read().text(getResource("text-suite.txt")); + Dataset<String> ds1 = spark.read().text(getResource("text-suite.txt")); Assert.assertEquals(4L, ds1.count()); - Dataset<String> ds2 = context.read().text( + Dataset<String> ds2 = spark.read().text( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, ds2.count()); @@ -346,7 +340,7 @@ public class JavaDataFrameSuite { @Test public void testCountMinSketch() { - Dataset<Long> df = context.range(1000); + Dataset<Long> df = spark.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -371,7 +365,7 @@ public class JavaDataFrameSuite { @Test public void testBloomFilter() { - Dataset<Long> df = context.range(1000); + Dataset<Long> df = spark.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f1b1c22e4a..8354a5bdac 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,46 +23,43 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; -import com.google.common.base.Objects; -import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; +import 
com.google.common.base.Objects; import org.junit.*; +import org.junit.rules.ExpectedException; import org.apache.spark.Accumulator; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) { @@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCollect() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); List<String> collected = ds.collectAsList(); Assert.assertEquals(Arrays.asList("hello", "world"), collected); } @@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testTake() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); List<String> collected = ds.takeAsList(1); Assert.assertEquals(Arrays.asList("hello"), collected); } @@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testToLocalIterator() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Iterator<String> iter = ds.toLocalIterator(); Assert.assertEquals("hello", iter.next()); Assert.assertEquals("world", iter.next()); @@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCommonOperation() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); Dataset<String> filtered = ds.filter(new FilterFunction<String>() { @@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable { public void testForeach() { final Accumulator<Integer> accum = jsc.accumulator(0); List<String> data = Arrays.asList("a", "b", "c"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); 
ds.foreach(new ForeachFunction<String>() { @Override @@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testReduce() { List<Integer> data = Arrays.asList(1, 2, 3); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()); int reduced = ds.reduce(new ReduceFunction<Integer>() { @Override @@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey( new MapFunction<String, Integer>() { @Override @@ -227,7 +224,7 @@ public class JavaDatasetSuite implements Serializable { toSet(reduced.collectAsList())); List<Integer> data2 = Arrays.asList(2, 6, 10); - Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); + Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()); KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey( new MapFunction<Integer, Integer>() { @Override @@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSelect() { List<Integer> data = Arrays.asList(2, 6); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()); Dataset<Tuple2<Integer, String>> selected = ds.select( expr("value + 1"), @@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSetOperation() { List<String> data = Arrays.asList("abc", "abc", "xyz"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List<String> data2 = Arrays.asList("xyz", "foo", "foo"); - Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING()); + Dataset<String> ds2 = spark.createDataset(data2, Encoders.STRING()); Dataset<String> intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); @@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable { @Test public void testJoin() { List<Integer> data = Arrays.asList(1, 2, 3); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a"); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a"); List<Integer> data2 = Arrays.asList(2, 3, 4); - Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b"); Dataset<Tuple2<Integer, Integer>> joined = ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); @@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable { public void testTupleEncoder() { Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); - Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2); + Dataset<Tuple2<Integer, String>> ds2 = spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); Encoder<Tuple3<Integer, Long, String>> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); 
List<Tuple3<Integer, Long, String>> data3 = Arrays.asList(new Tuple3<>(1, 2L, "a")); - Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3); + Dataset<Tuple3<Integer, Long, String>> ds3 = spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder<Tuple4<Integer, String, Long, String>> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List<Tuple4<Integer, String, Long, String>> data4 = Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); - Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4); + Dataset<Tuple4<Integer, String, Long, String>> ds4 = spark.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 = @@ -345,7 +342,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple5<Integer, String, Long, String, Boolean>> data5 = Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 = - context.createDataset(data5, encoder5); + spark.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); } @@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List<Tuple2<Tuple2<Integer, String>, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); - Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder); + Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); // test (int, (string, string, long)) @@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 = Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 = - context.createDataset(data2, encoder2); + spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); // test (int, ((string, long), string)) @@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 = - context.createDataset(data3, encoder3); + spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); } @@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable { 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds = - context.createDataset(data, encoder); + spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable { Encoder<KryoSerializable> encoder = Encoders.kryo(KryoSerializable.class); List<KryoSerializable> data = Arrays.asList( new KryoSerializable("hello"), new KryoSerializable("world")); - Dataset<KryoSerializable> ds = context.createDataset(data, encoder); + Dataset<KryoSerializable> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -450,14 +447,14 @@ public class JavaDatasetSuite implements 
Serializable { Encoder<JavaSerializable> encoder = Encoders.javaSerialization(JavaSerializable.class); List<JavaSerializable> data = Arrays.asList( new JavaSerializable("hello"), new JavaSerializable("world")); - Dataset<JavaSerializable> ds = context.createDataset(data, encoder); + Dataset<JavaSerializable> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @Test public void testRandomSplit() { List<String> data = Arrays.asList("hello", "world", "from", "spark"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); double[] arraySplit = {1, 2, 3}; List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1); @@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable { obj2.setF(Arrays.asList(300L, null, 400L)); List<SimpleJavaBean> data = Arrays.asList(obj1, obj2); - Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds.collectAsList()); NestedJavaBean obj3 = new NestedJavaBean(); obj3.setA(obj1); List<NestedJavaBean> data2 = Arrays.asList(obj3); - Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Dataset<NestedJavaBean> ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class)); Assert.assertEquals(data2, ds2.collectAsList()); Row row1 = new GenericRow(new Object[]{ @@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable { .add("d", createArrayType(StringType)) .add("e", createArrayType(StringType)) .add("f", createArrayType(LongType)); - Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema) .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } @@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable { obj.setB(new Date(0)); obj.setC(java.math.BigDecimal.valueOf(1)); Dataset<SimpleJavaBean2> ds = - context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } @@ -776,7 +773,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable { { Row row = new GenericRow(new Object[] { null }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 4a78dca7fe..2274912521 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -24,33 +24,30 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaUDFSuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @SuppressWarnings("unchecked") @@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable { // sqlContext.registerFunction( // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() { + spark.udf().register("stringLengthTest", new UDF1<String, Integer>() { @Override public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + Row result = spark.sql("SELECT stringLengthTest('test')").head(); Assert.assertEquals(4, result.getInt(0)); } @@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable { // (String str1, String str2) -> str1.length() + str2.length, // DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { + spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { @Override public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java index 7863177093..059c2d9f2c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java @@ -26,36 +26,30 @@ import scala.Tuple2; import org.junit.After; import org.junit.Before; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import 
org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.test.TestSparkSession; /** * Common test base shared across this and Java8DatasetAggregatorSuite. */ public class JavaDatasetAggregatorSuiteBase implements Serializable { - protected transient JavaSparkContext jsc; - protected transient TestSQLContext context; + private transient TestSparkSession spark; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) { @@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable { Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); List<Tuple2<String, Integer>> data = Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder); + Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder); return ds.groupByKey( new MapFunction<Tuple2<String, Integer>, String>() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 9e65158eb0..d0435e4d43 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources; import java.io.File; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; @@ -37,8 +39,8 @@ import org.apache.spark.util.Utils; public class JavaSaveLoadSuite { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; File path; Dataset<Row> df; @@ -52,9 +54,11 @@ public class JavaSaveLoadSuite { @Before public void setUp() throws IOException { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); @@ -66,16 +70,15 @@ public class JavaSaveLoadSuite { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD<String> rdd = sc.parallelize(jsonObjects); - df = sqlContext.read().json(rdd); + JavaRDD<String> rdd = jsc.parallelize(jsonObjects); + df = spark.read().json(rdd); df.registerTempTable("jsonTable"); } @After public void 
tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @Test @@ -83,7 +86,7 @@ public class JavaSaveLoadSuite { Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset<Row> loadedDF = spark.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -96,8 +99,8 @@ public class JavaSaveLoadSuite { List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset<Row> loadedDF = spark.read().format("json").schema(schema).options(options).load(); - checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList()); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 5ef20267f8..800316cde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -36,7 +36,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext import testImplicits._ def rddIdOf(tableName: String): Int = { - val plan = sqlContext.table(tableName).queryExecution.sparkPlan + val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => relation.cachedColumnBuffers.id @@ -73,41 +73,41 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - sqlContext.cacheTable("tempTable") + spark.catalog.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + spark.catalog.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != sqlContext.cacheManager.lookupCachedData(testData)) + assert(None != spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + spark.catalog.uncacheTable("tempTable") } 
test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - sqlContext.cacheTable("tempTable1") + spark.catalog.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - sqlContext.uncacheTable("tempTable2") + spark.catalog.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -117,101 +117,101 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val data = "*" * 1000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(sqlContext.table("bigData").count() === 200000L) - sqlContext.table("bigData").unpersist(blocking = true) + spark.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(spark.table("bigData").count() === 200000L) + spark.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - sqlContext.table("testData").cache() - assertCached(sqlContext.table("testData")) - sqlContext.table("testData").unpersist(blocking = true) + spark.table("testData").cache() + assertCached(spark.table("testData")) + spark.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - sqlContext.table("testData").cache() - sqlContext.table("testData").count() - sqlContext.table("testData").unpersist(blocking = true) - assertCached(sqlContext.table("testData"), 0) + spark.table("testData").cache() + spark.table("testData").count() + spark.table("testData").unpersist(blocking = true) + assertCached(spark.table("testData"), 0) } test("isCached") { - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") - assertCached(sqlContext.table("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + assertCached(spark.table("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - sqlContext.uncacheTable("testData") - assert(!sqlContext.isCached("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + spark.catalog.uncacheTable("testData") + assert(!spark.catalog.isCached("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - sqlContext.cacheTable("testData") - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, 
_, _: InMemoryTableScanExec, _) => r }.size } - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("read from cached table and uncache") { - sqlContext.cacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + assertCached(spark.table("testData")) - sqlContext.uncacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData"), 0) + spark.catalog.uncacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + assertCached(spark.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - sqlContext.cacheTable("selectStar") + spark.catalog.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - sqlContext.uncacheTable("selectStar") + spark.catalog.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(sqlContext.table("testData")) + assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -219,7 +219,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") + assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -228,14 +228,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(sqlContext.table("testCacheTable")) + assertCached(spark.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - sqlContext.uncacheTable("testCacheTable") + spark.catalog.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -243,14 +243,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(sqlContext.table("testCacheTable")) + assertCached(spark.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - 
sqlContext.uncacheTable("testCacheTable") + spark.catalog.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -258,7 +258,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(sqlContext.table("testData")) + assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -270,7 +270,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -278,7 +278,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -287,62 +287,62 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("Drops temporary table") { testData.select('key).registerTempTable("t1") - sqlContext.table("t1") - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) + spark.table("t1") + spark.catalog.dropTempTable("t1") + intercept[AnalysisException](spark.table("t1")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - sqlContext.cacheTable("t1") + spark.catalog.cacheTable("t1") - assert(sqlContext.isCached("t1")) - assert(sqlContext.isCached("t2")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) - assert(!sqlContext.isCached("t2")) + spark.catalog.dropTempTable("t1") + intercept[AnalysisException](spark.table("t1")) + assert(!spark.catalog.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") - sqlContext.clearCache() - assert(sqlContext.cacheManager.isEmpty) + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") + spark.catalog.clearCache() + assert(spark.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("Clear CACHE") - assert(sqlContext.cacheManager.isEmpty) + assert(spark.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() sql("SELECT * FROM 
t1").count() sql("SELECT * FROM t2").count() - val accId1 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId1 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => i.batchStats.id }.head - val accId2 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId2 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => i.batchStats.id }.head - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") assert(AccumulatorContext.get(accId1).isEmpty) assert(AccumulatorContext.get(accId2).isEmpty) @@ -351,7 +351,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") - sqlContext.cacheTable("abc") + spark.catalog.cacheTable("abc") val sparkPlan = sql( """select a.key, b.key, c.key from @@ -374,15 +374,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext table3x.registerTempTable("testData3x") sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") - sqlContext.cacheTable("orderedTable") - assertCached(sqlContext.table("orderedTable")) + spark.catalog.cacheTable("orderedTable") + assertCached(spark.table("orderedTable")) // Should not have an exchange as the query is already sorted on the group by key. verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0) checkAnswer( sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) - sqlContext.uncacheTable("orderedTable") - sqlContext.dropTempTable("orderedTable") + spark.catalog.uncacheTable("orderedTable") + spark.catalog.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. @@ -390,8 +390,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(numPartitions, $"key").registerTempTable("t1") testData2.repartition(numPartitions, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") // Joining them should result in no exchanges. 
verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) @@ -403,8 +403,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), sql("SELECT count(*) FROM testData GROUP BY key")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } @@ -412,8 +412,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(6, $"key").registerTempTable("t1") testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -421,16 +421,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Need to shuffle one side. withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -438,15 +438,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(12, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -454,8 +454,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. 
Since the number of partitions of @@ -464,30 +464,30 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTempTable("t1", "t2") { testData.repartition(6, $"value").registerTempTable("t1") testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 2) checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // repartition's column ordering is different from group by column ordering. // But they use the same set of columns. withTempTable("t1") { testData.repartition(6, $"value", $"key").registerTempTable("t1") - sqlContext.cacheTable("t1") + spark.catalog.cacheTable("t1") val query = sql("SELECT value, key from t1 group by key, value") verifyNumExchanges(query, 0) checkAnswer( query, testData.distinct().select($"value", $"key")) - sqlContext.uncacheTable("t1") + spark.catalog.uncacheTable("t1") } // repartition's column ordering is different from join condition's column ordering. @@ -499,8 +499,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext df1.repartition(6, $"value", $"key").registerTempTable("t1") val df2 = testData2.select($"a", $"b".cast("string")) df2.repartition(6, $"a", $"b").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") @@ -509,8 +509,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 19fe29a202..a5aecca13f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - sqlContext.createDataFrame(sparkContext.parallelize( + spark.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -287,7 +287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -308,7 +308,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), 
StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -351,7 +351,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("=!=") { - val nullData = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData = spark.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -370,7 +370,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) - val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData2 = spark.createDataFrame(sparkContext.parallelize( Row("abc") :: Row(null) :: Row("xyz") :: Nil), @@ -596,7 +596,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) + val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) assert(answer.contains(dir.getCanonicalPath)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 63f4b759a0..8a99866a33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -70,7 +70,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0))) ) - val decimalDataWithNulls = sqlContext.sparkContext.parallelize( + val decimalDataWithNulls = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, null) :: DecimalData(2, 1) :: @@ -114,7 +114,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 113000.0) :: Nil ) - val df0 = sqlContext.sparkContext.parallelize(Seq( + val df0 = spark.sparkContext.parallelize(Seq( Fact(20151123, 18, 35, "room1", 18.6), Fact(20151123, 18, 35, "room2", 22.4), Fact(20151123, 18, 36, "room1", 17.4), @@ -207,12 +207,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true) } test("agg without groups") { @@ -433,10 +433,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( - sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) checkAnswer( - sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0414fa1c91..031e66b57c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -154,7 +154,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // SPARK-12275: no physical plan for BroadcastHint in some condition withTempPath { path => df1.write.parquet(path.getCanonicalPath) - val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + val pf1 = spark.read.parquet(path.getCanonicalPath) assert(df1.join(broadcast(pf1)).count() === 4) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index c6d67519b0..fa8fa06907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -81,11 +81,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ } test("pivot max values enforced") { - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1) intercept[AnalysisException]( courseSales.groupBy("year").pivot("course") ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } @@ -104,7 +104,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ // pivot with extra columns to trigger optimization .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) .agg(sum($"earnings")) - val queryExecution = sqlContext.executePlan(df.queryExecution.logical) + val queryExecution = spark.executePlan(df.queryExecution.logical) assert(queryExecution.simpleString.contains("pivotfirst")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0ea7727e45..ab7733b239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -236,7 +236,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("sampleBy") { - val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) + val df = spark.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), @@ -247,7 +247,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in // `CountMinSketchSuite` in project spark-sketch. test("countMinSketch") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) assert(sketch1.totalCount() === 1000) @@ -279,7 +279,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // This test only verifies some basic requirements, more correctness tests can be found in // `BloomFilterSuite` in project spark-sketch. test("Bloom filter") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val filter1 = df.stat.bloomFilter("id", 1000, 0.03) assert(filter1.expectedFpp() - 0.03 < 1e-3) @@ -304,7 +304,7 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin // Turn on this test if you want to test the performance of approximate quantiles. 
ignore("computing quantiles should not take much longer than describe()") { - val df = sqlContext.range(5000000L).toDF("col1").cache() + val df = spark.range(5000000L).toDF("col1").cache() def seconds(f: => Any): Double = { // Do some warmup logDebug("warmup...") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 80a93ee6d4..f77403c13e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -99,8 +99,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) val schema2 = StructType(Array(StructField("label", IntegerType, false), StructField("point", new ExamplePointUDT(), false))) - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) checkAnswer( df1.union(df2).orderBy("label"), @@ -109,8 +109,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("empty data frame") { - assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(sqlContext.emptyDataFrame.count() === 0) + assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(spark.emptyDataFrame.count() === 0) } test("head and take") { @@ -369,7 +369,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake checkAnswer( - sqlContext.range(2).toDF().limit(2147483638), + spark.range(2).toDF().limit(2147483638), Row(0) :: Row(1) :: Nil ) } @@ -672,12 +672,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val parquetDir = new File(dir, "parquet").getCanonicalPath df.write.parquet(parquetDir) - val parquetDF = sqlContext.read.parquet(parquetDir) + val parquetDF = spark.read.parquet(parquetDir) assert(parquetDF.inputFiles.nonEmpty) val jsonDir = new File(dir, "json").getCanonicalPath df.write.json(jsonDir) - val jsonDF = sqlContext.read.json(jsonDir) + val jsonDF = spark.read.json(jsonDir) assert(parquetDF.inputFiles.nonEmpty) val unioned = jsonDF.union(parquetDF).inputFiles.sorted @@ -801,7 +801,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.rdd.collect() } @@ -818,14 +818,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sparkContext.makeRDD( + val df2 = spark.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -881,53 +881,53 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = 
sqlContext.range(0, 10, 1, 15).select("id") + val res1 = spark.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = sqlContext.range(3, 15, 3, 2).select("id") + val res2 = spark.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = sqlContext.range(1, -2).select("id") + val res3 = spark.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = sqlContext.range(1, -2, -2, 6).select("id") + val res4 = spark.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = sqlContext.range(-3, -8, -2, 1).select("id") + val res5 = spark.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = sqlContext.range(-8, -4, 2, 1).select("id") + val res6 = spark.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = sqlContext.range(-10, -9, -20, 1).select("id") + val res7 = spark.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = sqlContext.range(10).select("id") + val res10 = spark.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = sqlContext.range(-1).select("id") + val res11 = spark.range(-1).select("id") assert(res11.count == 0) // using the default slice number - val res12 = sqlContext.range(3, 15, 3).select("id") + val res12 = spark.range(3, 15, 3).select("id") assert(res12.count == 4) assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } @@ -993,13 +993,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) + val jdf = spark.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -1019,7 +1019,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - Dataset.ofRows(sqlContext.sparkSession, 
OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(spark, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } @@ -1062,7 +1062,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -1091,10 +1091,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(dir1, dir2), + checkAnswer(spark.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) - checkAnswer(sqlContext.read.format("json").load(dir1), + checkAnswer(spark.read.format("json").load(dir1), Row(1, 22) :: Nil) } } @@ -1116,7 +1116,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val input = spark.read.json(spark.sparkContext.makeRDD( (1 to 10).map(i => s"""{"id": $i}"""))) val df = input.select($"id", rand(0).as('r)) @@ -1185,7 +1185,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { withTempPath { path => Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath) - val df = sqlContext.read.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a")) } } @@ -1244,7 +1244,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { verifyExchangingAgg(testData.repartition($"key", $"value") .groupBy("key").count()) - val data = sqlContext.sparkContext.parallelize( + val data = spark.sparkContext.parallelize( (1 to 100).map(i => TestData2(i % 10, i))).toDF() // Distribute and order by. 
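// [Editor's illustration -- not part of the commit above. A minimal, self-contained sketch of
// the migration pattern these test hunks follow: entry points that used to hang off SQLContext
// (range, createDataFrame, read, cacheTable, setConf, udf.register) are reached through a
// SparkSession instead. The object, app and table names below are hypothetical.]
import org.apache.spark.sql.SparkSession

object SparkSessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    // was: new SQLContext(new SparkContext("local[*]", "testing"))
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("migration-sketch")
      .getOrCreate()

    // was: sqlContext.range(10)
    val df = spark.range(10).toDF("id")
    df.registerTempTable("t")

    // was: sqlContext.cacheTable("t") / sqlContext.uncacheTable("t")
    spark.catalog.cacheTable("t")
    assert(spark.sql("SELECT count(*) FROM t").head().getLong(0) == 10L)
    spark.catalog.uncacheTable("t")

    // was: sqlContext.setConf("spark.sql.shuffle.partitions", "4")
    spark.conf.set("spark.sql.shuffle.partitions", "4")

    // was: sqlContext.udf.register("plusOne", (x: Long) => x + 1)
    spark.udf.register("plusOne", (x: Long) => x + 1)
    assert(spark.sql("SELECT plusOne(id) FROM t WHERE id = 1").head().getLong(0) == 2L)

    spark.stop()
  }
}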
@@ -1308,7 +1308,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withTempPath { path => val p = path.getAbsolutePath Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p) - checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012)) + checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012)) } } } @@ -1317,7 +1317,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( rdd, new StructType().add("f1", IntegerType).add("f2", IntegerType), needsConversion = false).select($"F1", $"f2".as("f2")) @@ -1344,7 +1344,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) - sqlContext.udf.register("boxedUDF", + spark.udf.register("boxedUDF", (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) @@ -1393,7 +1393,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("reuse exchange") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan checkAnswer(join, df) @@ -1415,14 +1415,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("sameResult() on aggregate") { - val df = sqlContext.range(100) + val df = spark.range(100) val agg1 = df.groupBy().count() val agg2 = df.groupBy().count() // two aggregates with different ExprId within them should have same result assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) val agg3 = df.groupBy().sum() assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) - val df2 = sqlContext.range(101) + val df2 = spark.range(101) val agg4 = df2.groupBy().count() assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) } @@ -1454,24 +1454,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("assertAnalyzed shouldn't replace original stack trace") { val e = intercept[AnalysisException] { - sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) } assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) } test("SPARK-13774: Check error message for non existent path without globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("csv"). + val e = intercept[AnalysisException] (spark.read.format("csv"). load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage() assert(e.startsWith("Path does not exist")) } test("SPARK-13774: Check error message for not existent globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("text"). + val e = intercept[AnalysisException] (spark.read.format("text"). load( "/xyz/*")).getMessage() assert(e.startsWith("Path does not exist")) - val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd). + val e1 = intercept[AnalysisException] (spark.read.json("/mnt/*/*-xyz.json").rdd). 
getMessage() assert(e1.startsWith("Path does not exist")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 06584ec21e..a957d5ba25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -249,14 +249,14 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B try { f(tableName) } finally { - sqlContext.dropTempTable(tableName) + spark.catalog.dropTempTable(tableName) } } test("time window in SQL with single string expression") { withTempTable { table => checkAnswer( - sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + spark.sql(s"""select window(time, "10 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), @@ -270,7 +270,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with two expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( @@ -285,7 +285,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with three expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 68e99d6a6b..fe6ba83b4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -48,7 +48,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b3", FloatType) .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(struct)) } @@ -70,7 +70,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b5b", StringType)) .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index ae9fb80c68..d8e241c62f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -31,14 +31,14 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = 
{ - import sqlContext.implicits._ + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,17 +72,17 @@ object DatasetBenchmark { benchmark } - def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { - import sqlContext.implicits._ + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -130,13 +130,13 @@ object DatasetBenchmark { override def outputEncoder: Encoder[Long] = Encoders.scalaLong } - def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { - import sqlContext.implicits._ + def aggregate(spark: SparkSession, numRows: Long): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("aggregate", numRows) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD sum") { iter => rdd.aggregate(0L)(_ + _.l, _ + _) } @@ -157,15 +157,17 @@ object DatasetBenchmark { } def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) + val spark = SparkSession.builder + .master("local[*]") + .appName("Dataset benchmark") + .getOrCreate() val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(sqlContext, numRows, numChains) - val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) - val benchmark3 = aggregate(sqlContext, numRows) + val benchmark = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilter(spark, numRows, numChains) + val benchmark3 = aggregate(spark, numRows) /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 942cc09b6d..8c0906b746 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -39,7 +39,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { 2, 3, 4) // Drop the cache. 
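// [Editor's aside -- not part of the commit. Where these assertions previously used
// SQLContext.isCached, cache state is now checked through the session's CacheManager,
// as the replacement just below does immediately after unpersist():
//   cached.unpersist()
//   assert(spark.cacheManager.lookupCachedData(cached).isEmpty, "The Dataset should not be cached.")
// ]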
cached.unpersist() - assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + assert(spark.cacheManager.lookupCachedData(cached).isEmpty, "The Dataset should not be cached.") } test("persist and then rebind right encoder when join 2 datasets") { @@ -56,9 +56,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(joined, 2) ds1.unpersist() - assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds1).isEmpty, + "The Dataset ds1 should not be cached.") ds2.unpersist() - assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds2).isEmpty, + "The Dataset ds2 should not be cached.") } test("persist and then groupBy columns asKey, map") { @@ -73,8 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(agged.filter(_._1 == "b")) ds.unpersist() - assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + assert(spark.cacheManager.lookupCachedData(ds).isEmpty, "The Dataset ds should not be cached.") agged.unpersist() - assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + assert(spark.cacheManager.lookupCachedData(agged).isEmpty, + "The Dataset agged should not be cached.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3cb4e52c6d..3c8c862c22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -46,12 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("range") { - assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) } test("SPARK-12404: Datatype Helper Serializability") { @@ -472,7 +472,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("SPARK-14696: implicit encoders for boxed types") { - assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L) + assert(spark.range(1).map { i => i : java.lang.Long }.head == 0L) } test("SPARK-11894: Incorrect results are returned when using null") { @@ -510,8 +510,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) def buildDataset(rows: Row*): Dataset[NestedStruct] = { - val rowRDD = sqlContext.sparkContext.parallelize(rows) - sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + val rowRDD = spark.sparkContext.parallelize(rows) + spark.createDataFrame(rowRDD, schema).as[NestedStruct] } checkDataset( @@ -626,7 +626,7 @@ class DatasetSuite extends QueryTest with 
SharedSQLContext { } test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { - val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + val wideDF = spark.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) // Make sure the generated code for this plan can compile and execute. checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) } @@ -654,7 +654,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { dataset.join(actual, dataset("user") === actual("id")).collect() } - test("SPARK-15097: implicits on dataset's sqlContext can be imported") { + test("SPARK-15097: implicits on dataset's spark can be imported") { val dataset = Seq(1, 2, 3).toDS() checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) } @@ -735,10 +735,10 @@ object JavaData { def apply(a: Int): JavaData = new JavaData(a) } -/** Used to test importing dataset.sqlContext.implicits._ */ +/** Used to test importing dataset.spark.implicits._ */ object DatasetTransform { def addOne(ds: Dataset[Int]): Dataset[Int] = { - import ds.sqlContext.implicits._ + import ds.sparkSession.implicits._ ds.map(_ + 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index b1987c6908..a41b465548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -51,7 +51,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { test("insert an extraStrategy") { try { - sqlContext.experimental.extraStrategies = TestStrategy :: Nil + spark.experimental.extraStrategies = TestStrategy :: Nil val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( @@ -62,7 +62,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { df.select("a", "b"), Row("so slow", 1)) } finally { - sqlContext.experimental.extraStrategies = Nil + spark.experimental.extraStrategies = Nil } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8cbad04e23..da567db5ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( @@ -112,7 +112,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // } test("broadcasted hash join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") Seq( ("SELECT * FROM testData join testData2 ON key = a", @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") 
sql("CACHE TABLE testData2") Seq( @@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -435,7 +435,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted existence join operator selection") { - sqlContext.cacheManager.clearCache() + spark.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { @@ -461,17 +461,17 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("cross join with broadcast") { sql("CACHE TABLE testData") - val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + val sizeInByteOfTestData = statisticSizeInByte(spark.table("testData")) // we set the threshold is greater than statistic of the cached table testData withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { - assert(statisticSizeInByte(sqlContext.table("testData2")) > - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData2")) > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) - assert(statisticSizeInByte(sqlContext.table("testData")) < - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData")) < + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 9f6c86a575..c88dfe5f24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -33,36 +33,36 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) } test("get all tables") { checkAnswer( - sqlContext.tables().filter("tableName = 'listtablessuitetable'"), + spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( sql("SHOW tables").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("getting all tables with a database name has no impact on returned table names") { checkAnswer( - sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"), + spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( sql("show TABLES in default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - 
assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { + Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -81,9 +81,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "listtablessuitetable") ) checkAnswer( - sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - sqlContext.dropTempTable("tables") + spark.catalog.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala new file mode 100644 index 0000000000..1732977ee5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach +import org.scalatest.Suite + +/** Manages a local `spark` {@link SparkSession} variable, correctly stopping it after each test. */ +trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => + + @transient var spark: SparkSession = _ + + override def beforeAll() { + super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + } + + override def afterEach() { + try { + resetSparkContext() + } finally { + super.afterEach() + } + } + + def resetSparkContext(): Unit = { + LocalSparkSession.stop(spark) + spark = null + } + +} + +object LocalSparkSession { + def stop(spark: SparkSession) { + if (spark != null) { + spark.stop() + } + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. 
*/ + def withSparkSession[T](sc: SparkSession)(f: SparkSession => T): T = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index df8b3b7d87..a1a9b66c1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { - protected def sqlContext: SQLContext + protected def spark: SparkSession // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) + spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } @@ -267,7 +267,7 @@ abstract class QueryTest extends PlanTest { val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) } catch { case NonFatal(e) => fail( @@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest { def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession) + LogicalRDD(l.output, origin.rdd)(spark) case l: LocalRelation => val origin = localRelations.pop() l.copy(data = origin.data) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1ff288cd19..e401abef29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -57,7 +57,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + StringUtils.filterPattern(spark.sessionState.functionRegistry.listFunction(), pattern) .map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions("*")) @@ -88,7 +88,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-14415: All functions should have own descriptions") { - for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + for (f <- spark.sessionState.functionRegistry.listFunction()) { if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } @@ -102,7 +102,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - sqlContext.cacheTable("cachedData") + spark.catalog.cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) @@ -193,7 +193,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - sqlContext.read.json(sparkContext.parallelize( + spark.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") @@ -211,7 +211,7 @@ class 
SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6201 IN type conversion") { - sqlContext.read.json( + spark.read.json( sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11226 Skip empty line in json file") { - sqlContext.read.json( + spark.read.json( sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) .registerTempTable("d") @@ -258,9 +258,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregation with codegen") { // Prepare a table that we can group some rows. - sqlContext.table("testData") - .union(sqlContext.table("testData")) - .union(sqlContext.table("testData")) + spark.table("testData") + .union(spark.table("testData")) + .union(spark.table("testData")) .registerTempTable("testData3x") try { @@ -333,7 +333,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "SELECT sum('a'), avg('a'), count(null) FROM testData", Row(null, null, 0) :: Nil) } finally { - sqlContext.dropTempTable("testData3x") + spark.catalog.dropTempTable("testData3x") } } @@ -1041,7 +1041,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SET commands semantics using sql()") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -1082,17 +1082,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql(s"SET $nonexistentKey"), Row(nonexistentKey, "<undefined>") ) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("SET commands with illegal or inappropriate argument") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() // Set negative mapred.reduce.tasks for automatically determining // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("apply schema") { @@ -1110,7 +1110,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df1 = spark.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -1140,7 +1140,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df2 = spark.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -1165,7 +1165,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD3, schema2) + val df3 = spark.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1210,7 +1210,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) + val 
personWithMeta = spark.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1226,7 +1226,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3371 Renaming a function expression with group by gives error") { - sqlContext.udf.register("len", (s: String) => s.length) + spark.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1409,7 +1409,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3483 Special chars in column names") { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - sqlContext.read.json(data).registerTempTable("records") + spark.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1450,15 +1450,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - sqlContext.dropTempTable("data") + spark.catalog.dropTempTable("data") - sqlContext.read.json( + spark.read.json( sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - sqlContext.dropTempTable("data") + spark.catalog.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1504,7 +1504,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1517,14 +1517,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: special cases") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1628,7 +1628,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } @@ -1776,7 +1776,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // We don't support creating a temporary table while specifying a database intercept[AnalysisException] { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE db.t |USING parquet @@ -1787,7 +1787,7 @@ class SQLQuerySuite extends 
QueryTest with SharedSQLContext { }.getMessage // If you use backticks to quote the name then it's OK. - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE `db.t` |USING parquet @@ -1795,7 +1795,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | path '$path' |) """.stripMargin) - checkAnswer(sqlContext.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) } } @@ -1818,7 +1818,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("run sql directly on files") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() withTempPath(f => { df.write.json(f.getCanonicalPath) checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"), @@ -1880,7 +1880,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11303: filter should not be pushed down into sample") { - val df = sqlContext.range(100) + val df = spark.range(100) List(true, false).foreach { withReplacement => val sampled = df.sample(withReplacement, 0.1, 1) val sampledOdd = sampled.filter("id % 2 != 0") @@ -2059,7 +2059,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // Identity udf that tracks the number of times it is called. val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { + spark.udf.register("testUdf", (x: Int) => { countAcc.++=(1) x }) @@ -2093,9 +2093,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + spark.conf.set("spark.sql.subexpressionElimination.enabled", "false") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + spark.conf.set("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index ddab918629..b489b74fec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val _sqlContext = new SQLContext(sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) + val spark = SparkSession.builder.getOrCreate() + new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 6fb1aca769..1ab562f873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -290,7 +290,7 @@ trait StreamTest extends QueryTest with Timeouts { verify(currentStream == null, "stream already running") lastStream = currentStream currentStream = - sqlContext + spark .streams .startQuery( StreamExecution.nextName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 
6809f26968..c7b95c2683 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -281,7 +281,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("number format function") { - val df = sqlContext.range(1) + val df = spark.range(1) checkAnswer( df.select(format_number(lit(5L), 4)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ec950332c5..427f24a9f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - sqlContext.dropTempTable("tmp_table") + spark.catalog.dropTempTable("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) - sqlContext.dropTempTable("test_table") + spark.catalog.dropTempTable("test_table") } } test("error reporting for incorrect number of arguments") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("error reporting for undefined functions") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -88,22 +88,22 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("Simple UDF") { - sqlContext.udf.register("strLenScala", (_: String).length) + spark.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - sqlContext.udf.register("random0", () => { Math.random()}) + spark.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) + spark.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -115,7 +115,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a HAVING") { - 
sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -134,7 +134,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a GROUP BY") { - sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -151,10 +151,10 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDFs everywhere") { - sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) - sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) - sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) - sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -173,7 +173,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("struct UDF") { - sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + spark.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") @@ -182,27 +182,27 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("udf that is transformed") { - sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + spark.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - sqlContext.udf.register("intExpected", (x: Int) => x) + spark.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } test("udf in different types") { - sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) - sqlContext.udf.register("decimalDataFunc", + spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + spark.udf.register("decimalDataFunc", (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) - sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) - sqlContext.udf.register("arrayDataFunc", + spark.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + spark.udf.register("arrayDataFunc", (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) - sqlContext.udf.register("mapDataFunc", + spark.udf.register("mapDataFunc", (data: scala.collection.Map[Int, String]) => { data }) - sqlContext.udf.register("complexDataFunc", + spark.udf.register("complexDataFunc", (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) checkAnswer( @@ -235,7 +235,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { - val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + val myUDF = spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) // Without the fix, this will fail because we fail to cast data type of b to string // because myUDF does not know its input data type. With the fix, this query should not diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a49aaa8b73..3057e016c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -94,7 +94,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT } test("UDTs and UDFs") { - sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) + spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -106,7 +106,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val path = dir.getCanonicalPath pointsRDD.write.parquet(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) @@ -118,7 +118,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val path = dir.getCanonicalPath pointsRDD.repartition(1).write.parquet(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) @@ -146,7 +146,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT )) val stringRDD = sparkContext.parallelize(data) - val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) + val jsonRDD = spark.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -167,7 +167,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT )) val stringRDD = sparkContext.parallelize(data) - val jsonDataset = 
sqlContext.read.schema(schema).json(stringRDD) + val jsonDataset = spark.read.schema(schema).json(stringRDD) .as[(Int, UDT.MyDenseVector)] checkDataset( jsonDataset, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 01d485ce2d..70a00a43f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.TestSQLContext class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -251,7 +250,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } def withSQLContext( - f: SQLContext => Unit, + f: SparkSession => Unit, targetNumPostShufflePartitions: Int, minNumPostShufflePartitions: Option[Int]): Unit = { val sparkConf = @@ -272,9 +271,11 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1") } - val sparkContext = new SparkContext(sparkConf) - val sqlContext = new TestSQLContext(sparkContext) - try f(sqlContext) finally sparkContext.stop() + + val spark = SparkSession.builder + .config(sparkConf) + .getOrCreate() + try f(spark) finally spark.stop() } Seq(Some(3), None).foreach { minNumPostShufflePartitions => @@ -284,9 +285,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: aggregate operator$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 20 as key", "id as value") val agg = df.groupBy("key").count @@ -294,7 +295,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. checkAnswer( agg, - sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect()) + spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. @@ -325,13 +326,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: join operator$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") @@ -339,10 +340,10 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. 
val expectedAnswer = - sqlContext + spark .range(0, 1000) .selectExpr("id % 500 as key", "id as value") - .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) checkAnswer( join, expectedAnswer.collect()) @@ -376,16 +377,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: complex query 1$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") .groupBy("key1") .count .toDF("key1", "cnt1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") .groupBy("key2") @@ -396,7 +397,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. val expectedAnswer = - sqlContext + spark .range(0, 500) .selectExpr("id", "2 as cnt") checkAnswer( @@ -428,16 +429,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: complex query 2$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") .groupBy("key1") .count .toDF("key1", "cnt1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") @@ -448,7 +449,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. val expectedAnswer = - sqlContext + spark .range(0, 1000) .selectExpr("id % 500 as key", "2 as cnt", "id as value") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index ba16810cee..36cde3233d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -50,7 +50,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { } test("BroadcastExchange same result") { - val df = sqlContext.range(10) + val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) @@ -75,7 +75,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { } test("ShuffleExchange same result") { - val df = sqlContext.range(10) + val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3b2911d056..d2e1ea12fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.sessionState.planner + val planner = spark.sessionState.planner import planner._ val plannedOption = Aggregation(query).headOption val planned = @@ -78,7 +78,7 @@ class PlannerSuite extends SharedSQLContext { val 
schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = sparkContext.parallelize(row :: Nil) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + spark.createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -136,7 +136,7 @@ class PlannerSuite extends SharedSQLContext { sql("CACHE TABLE tiny") val a = testData.as("a") - val b = sqlContext.table("tiny").as("b") + val b = spark.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoinExec => join } @@ -145,7 +145,7 @@ class PlannerSuite extends SharedSQLContext { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join") - sqlContext.clearCache() + spark.catalog.clearCache() } } } @@ -154,8 +154,8 @@ class PlannerSuite extends SharedSQLContext { withTempPath { file => val path = file.getCanonicalPath testData.write.parquet(path) - val df = sqlContext.read.parquet(path) - sqlContext.registerDataFrameAsTable(df, "testPushed") + val df = spark.read.parquet(path) + spark.wrapped.registerDataFrameAsTable(df, "testPushed") withTempTable("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan @@ -295,7 +295,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -315,7 +315,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -333,7 +333,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -353,7 +353,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") @@ -376,7 +376,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val 
outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") @@ -392,7 +392,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -408,7 +408,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") @@ -425,7 +425,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -444,7 +444,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") @@ -464,7 +464,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") @@ -493,7 +493,7 @@ class PlannerSuite extends SharedSQLContext { shuffle, shuffle) - val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } @@ -510,7 +510,7 @@ class PlannerSuite extends SharedSQLContext { ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) - val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) + val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: 
ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index c9f517ca34..ad41111bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.Properties import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { @@ -50,16 +50,19 @@ class SQLExecutionSuite extends SparkFunSuite { } test("concurrent query execution with fork-join pool (SPARK-13747)") { - val sc = new SparkContext("local[*]", "test") - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder + .master("local[*]") + .appName("test") + .getOrCreate() + + import spark.implicits._ try { // Should not throw IllegalArgumentException (1 to 100).par.foreach { _ => - sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() + spark.sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } } finally { - sc.stop() + spark.sparkContext.stop() } } @@ -67,8 +70,8 @@ class SQLExecutionSuite extends SparkFunSuite { * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. */ private def testConcurrentQueryExecution(sc: SparkContext): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder.getOrCreate() + import spark.implicits._ // Initialize local properties. This is necessary for the test to pass. sc.getLocalProperties diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 073e0b3f00..d7eae21f9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,7 +21,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.test.SQLTestUtils @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SQLTestUtils * class's test helper methods can be used, see [[SortSuite]]. */ private[sql] abstract class SparkPlanTest extends SparkFunSuite { - protected def sqlContext: SQLContext + protected def spark: SparkSession /** * Runs the plan and makes sure the answer matches the expected result. 
@@ -90,9 +90,10 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { - case Some(errorMessage) => fail(errorMessage) - case None => + SparkPlanTest + .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match { + case Some(errorMessage) => fail(errorMessage) + case None => } } @@ -114,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -141,13 +142,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + spark: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -162,7 +163,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -202,12 +203,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + spark: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -230,8 +231,8 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - val execution = new QueryExecution(sqlContext.sparkSession, null) { + private def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = { + val execution = new QueryExecution(spark.sparkSession, null) { override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 233104ae84..ada60f6919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("range/filter should be combined") { - val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") + val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) assert(df.collect() === Array(Row(2))) } 
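The test changes above all follow one mechanical substitution: the SparkContext/SQLContext pair is replaced by a single SparkSession. Below is a minimal standalone sketch of that pattern, not part of the diff; the object and UDF names are made up for illustration, and the suites themselves obtain spark from their shared test harness rather than building one.

import org.apache.spark.sql.SparkSession

object SparkSessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    // SparkSession.builder replaces `new SparkContext(...)` + `new SQLContext(sc)`.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("migration-sketch")
      .getOrCreate()

    // UDF registration, runtime configuration and DataFrame creation now hang
    // off the session instead of the SQLContext.
    spark.udf.register("plusOne", (x: Long) => x + 1)
    spark.conf.set("spark.sql.shuffle.partitions", "4")
    val df = spark.range(10).toDF("id")
    df.selectExpr("plusOne(id)").show()

    // The underlying SparkContext is still reachable when an RDD is needed.
    val counted = spark.sparkContext.parallelize(1 to 5).count()
    assert(counted == 5)

    spark.stop()
  }
}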
test("Aggregate should be included in WholeStageCodegen") { - val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id"))) + val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id"))) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -44,7 +44,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy("id").count().orderBy("id") val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -53,10 +53,10 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("BroadcastHashJoin should be included in WholeStageCodegen") { - val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) val schema = new StructType().add("k", IntegerType).add("v", StringType) - val smallDF = sqlContext.createDataFrame(rdd, schema) - val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + val smallDF = spark.createDataFrame(rdd, schema) + val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id")) assert(df.queryExecution.executedPlan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined) @@ -64,7 +64,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("Sort should be included in WholeStageCodegen") { - val df = sqlContext.range(3, 0, -1).toDF().sort(col("id")) + val df = spark.range(3, 0, -1).toDF().sort(col("id")) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -75,7 +75,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("MapElements should be included in WholeStageCodegen") { import testImplicits._ - val ds = sqlContext.range(10).map(_.toString) + val ds = spark.range(10).map(_.toString) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -84,7 +84,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("typed filter should be included in WholeStageCodegen") { - val ds = sqlContext.range(10).filter(_ % 2 == 0) + val ds = spark.range(10).filter(_ % 2 == 0) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -93,7 +93,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("back-to-back typed filter should be included in WholeStageCodegen") { - val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 50c8745a28..88269a6a2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 
+21,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ @@ -32,7 +33,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -42,14 +43,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // TODO: Improve this test when we have better statistics sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - sqlContext.cacheTable("sizeTst") + spark.catalog.cacheTable("sizeTst") assert( - sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - sqlContext.conf.autoBroadcastJoinThreshold) + spark.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } test("projection") { - val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan + val plan = spark.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("repeatedData") + spark.catalog.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("nullableRepeatedData") + spark.catalog.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -97,7 +98,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - sqlContext.cacheTable("timestamps") + spark.catalog.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -109,7 +110,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("withEmptyParts") + spark.catalog.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -178,35 +179,35 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - 
sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + spark.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan + spark.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - sqlContext.isCached("InMemoryCache_different_data_types"), + spark.catalog.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - sqlContext.table("InMemoryCache_different_data_types").collect()) - sqlContext.dropTempTable("InMemoryCache_different_data_types") + spark.table("InMemoryCache_different_data_types").collect()) + spark.catalog.dropTempTable("InMemoryCache_different_data_types") } test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { - val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + val df = spark.range(1, 100).selectExpr("id % 10 as id") .rdd.map(id => Tuple1(s"str_$id")).toDF("i") val cached = df.cache() // count triggers the caching action. It should not throw. cached.count() // Make sure, the DataFrame is indeed cached. - assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + assert(spark.cacheManager.lookupCachedData(cached).nonEmpty) // Check result. checkAnswer( cached, - sqlContext.range(1, 100).selectExpr("id % 10 as id") + spark.range(1, 100).selectExpr("id % 10 as id") .rdd.map(id => Tuple1(s"str_$id")).toDF("i") ) @@ -215,7 +216,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") { - val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s") + val data = spark.range(10).selectExpr("id", "cast(id as string) as s") data.cache() assert(data.count() === 10) assert(data.filter($"s" === "3").count() === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 9164074a3e..48c798986b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -32,23 +32,24 @@ class PartitionBatchPruningSuite import testImplicits._ - private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) + private lazy val originalInMemoryPartitionPruning = + spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) override protected def beforeAll(): Unit = { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, 10) // Enable in-memory partition pruning - sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, true) // Enable in-memory table 
scan accumulators - sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + spark.conf.set("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { try { - sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, originalColumnBatchSize) + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, originalInMemoryPartitionPruning) } finally { super.afterAll() } @@ -63,12 +64,12 @@ class PartitionBatchPruningSuite TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") - sqlContext.cacheTable("pruningData") + spark.catalog.cacheTable("pruningData") } override protected def afterEach(): Unit = { try { - sqlContext.uncacheTable("pruningData") + spark.catalog.uncacheTable("pruningData") } finally { super.afterEach() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3586ddf7b6..5fbab2382a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,7 +37,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test - sqlContext.sessionState.catalog.reset() + spark.sessionState.catalog.reset() } finally { super.afterEach() } @@ -66,7 +66,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", sqlContext.conf.warehousePath, Map()), ignoreIfExists = false) + CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + ignoreIfExists = false) } private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = { @@ -111,7 +112,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("the qualified path of a database is stored in the catalog") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog withTempDir { tmpDir => val path = tmpDir.toString @@ -274,7 +275,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => val dbNameWithoutBackTicks = cleanIdentifier(dbName) - assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) + assert(!spark.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) var message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName") @@ -334,7 +335,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create table in default db") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent1 = TableIdentifier("tab1", None) createTable(catalog, tableIdent1) val expectedTableIdent = tableIdent1.copy(database = Some("default")) @@ -343,7 +344,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create table in a specific db") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog createDatabase(catalog, "dbx") val tableIdent1 = 
TableIdentifier("tab1", Some("dbx")) createTable(catalog, tableIdent1) @@ -352,7 +353,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent1 = TableIdentifier("tab1", Some("dbx")) val tableIdent2 = TableIdentifier("tab2", Some("dbx")) createDatabase(catalog, "dbx") @@ -444,7 +445,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: set properties") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -471,7 +472,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: unset properties") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -512,7 +513,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: bucketing is not supported") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -523,7 +524,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: skew is not supported") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -560,7 +561,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename partition") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1") val part2 = Map("b" -> "2") @@ -661,7 +662,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("drop table - temporary table") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog sql( """ |CREATE TEMPORARY TABLE tab1 @@ -686,7 +687,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testDropTable(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -705,7 +706,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // SQLContext does not support create view. 
Log an error message, if tab1 does not exists sql("DROP VIEW tab1") - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -726,7 +727,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testSetLocation(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1") createDatabase(catalog, "dbx") @@ -784,7 +785,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testSetSerde(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -830,7 +831,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testAddPartitions(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1") val part2 = Map("b" -> "2") @@ -880,7 +881,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testDropPartitions(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1") val part2 = Map("b" -> "2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index ac2af77a6e..52dda8c6ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -281,7 +281,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi )) val fakeRDD = new FileScanRDD( - sqlContext.sparkSession, + spark, (file: PartitionedFile) => Iterator.empty, Seq(partition) ) @@ -399,7 +399,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi util.stringToFile(file, "*" * size) } - val df = sqlContext.read + val df = spark.read .format(classOf[TestFileFormat].getName) .load(tempDir.getCanonicalPath) @@ -409,7 +409,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi l.copy(relation = r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) } - Dataset.ofRows(sqlContext.sparkSession, bucketed) + Dataset.ofRows(spark, bucketed) } else { df } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index 297731c70c..89d57653ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -27,7 +27,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { 
test("sizeInBytes should be the total size of all files") { withTempDir{ dir => dir.delete() - sqlContext.range(1000).write.parquet(dir.toString) + spark.range(1000).write.parquet(dir.toString) // ignore hidden files val allFiles = dir.listFiles(new FilenameFilter { override def accept(dir: File, name: String): Boolean = { @@ -35,7 +35,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { } }) val totalSize = allFiles.map(_.length()).sum - val df = sqlContext.read.parquet(dir.toString) + val df = spark.read.parquet(dir.toString) assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 28e59055fa..b6cdc8cfab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -91,7 +91,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -101,7 +101,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test with calling another function to load") { - val cars = sqlContext + val cars = spark .read .option("header", "false") .csv(testFile(carsFile)) @@ -110,7 +110,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test with type inference") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "true") @@ -121,7 +121,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test inferring booleans") { - val result = sqlContext.read + val result = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -133,7 +133,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with alternative delimiter and quote") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true")) .load(testFile(carsAltFile)) @@ -142,7 +142,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("parse unescaped quotes with maxCharsPerColumn") { - val rows = sqlContext.read + val rows = spark.read .format("csv") .option("maxCharsPerColumn", "4") .load(testFile(unescapedQuotesFile)) @@ -154,7 +154,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("bad encoding name") { val exception = intercept[UnsupportedCharsetException] { - sqlContext + spark .read .format("csv") .option("charset", "1-9588-osi") @@ -166,7 +166,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test different encoding") { // scalastyle:off - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsFile8859)}", header "true", @@ -174,12 +174,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) // scalastyle:on - verifyCars(sqlContext.table("carsTable"), withHeader = true) + verifyCars(spark.table("carsTable"), withHeader = true) } test("test aliases sep and encoding for delimiter and charset") { // scalastyle:off - val cars = sqlContext + val 
cars = spark .read .format("csv") .option("header", "true") @@ -192,17 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with tab separated file") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") """.stripMargin.replaceAll("\n", " ")) - verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + verifyCars(spark.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) } test("DDL test parsing decimal type") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, priceTag decimal, @@ -212,11 +212,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) assert( - sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + spark.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) } test("test for DROPMALFORMED parsing mode") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -226,7 +226,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test for FAILFAST parsing mode") { val exception = intercept[SparkException]{ - sqlContext.read + spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() @@ -236,7 +236,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for tokens more than the fields in the schema") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -247,7 +247,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with null quote character") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .option("quote", "") @@ -258,7 +258,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with empty file and known schema") { - val result = sqlContext.read + val result = spark.read .format("csv") .schema(StructType(List(StructField("column", StringType, false)))) .load(testFile(emptyFile)) @@ -268,25 +268,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with empty file") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, grp string) |USING csv |OPTIONS (path "${testFile(emptyFile)}", header "false") """.stripMargin.replaceAll("\n", " ")) - assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + assert(spark.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) } test("DDL test with schema") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, blank string) |USING csv |OPTIONS (path "${testFile(carsFile)}", header "true") """.stripMargin.replaceAll("\n", " ")) - val cars = sqlContext.table("carsTable") + val cars = spark.table("carsTable") verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) assert( cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) @@ 
-295,7 +295,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -304,7 +304,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .csv(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -316,7 +316,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with quote") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -327,7 +327,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("quote", "\"") .save(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .option("quote", "\"") @@ -338,7 +338,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false")) .load(testFile(commentsFile)) @@ -353,7 +353,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("inferring schema with commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) .load(testFile(commentsFile)) @@ -372,7 +372,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "true", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .load(testFile(datesFile)) @@ -393,7 +393,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "false", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .schema(customSchema) @@ -416,7 +416,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("setting comment to null disables comment support") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "", "header" -> "false")) .load(testFile(disableCommentsFile)) @@ -439,7 +439,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("model", StringType, nullable = false), StructField("comment", StringType, nullable = true), StructField("blank", StringType, nullable = true))) - val cars = sqlContext.read + val cars = spark.read .format("csv") .schema(dataSchema) .options(Map("header" -> "true", "nullValue" -> "null")) @@ -454,7 +454,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -468,7 +468,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() 
assert(compressedFiles.exists(_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -486,7 +486,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ) withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -502,7 +502,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -513,7 +513,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Schema inference correctly identifies the datatype when data is sparse.") { - val df = sqlContext.read + val df = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -525,7 +525,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("old csv data source name works") { - val cars = sqlContext + val cars = spark .read .format("com.databricks.spark.csv") .option("header", "false") @@ -535,7 +535,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("nulls, NaNs and Infinity values can be parsed") { - val numbers = sqlContext + val numbers = spark .read .format("csv") .schema(StructType(List( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 1742df31bb..c31dffedbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -27,16 +27,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowComments off") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowComments on") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowComments", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowComments", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -44,16 +44,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowSingleQuotes off") { val str = """{'name': 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowSingleQuotes", "false").json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowSingleQuotes on") { val str = """{'name': 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = 
spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -61,16 +61,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowUnquotedFieldNames off") { val str = """{name: 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowUnquotedFieldNames on") { val str = """{name: 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowUnquotedFieldNames", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -78,16 +78,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowNumericLeadingZeros off") { val str = """{"age": 0018}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowNumericLeadingZeros on") { val str = """{"age": 0018}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowNumericLeadingZeros", "true").json(rdd) assert(df.schema.head.name == "age") assert(df.first().getLong(0) == 18) @@ -97,16 +97,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. 
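The JsonParsingOptionsSuite hunks above follow the same substitution; the Jackson feature flags are still passed as plain read options. A small sketch of the on/off behaviour being asserted, assuming a session named spark:

  val json = """{name: 'Reynold Xin'}"""   // unquoted field name
  val rdd = spark.sparkContext.parallelize(Seq(json))

  // With the default parser settings the record is kept under _corrupt_record;
  // enabling the relaxed feature yields a proper "name" column instead.
  val strictDf = spark.read.json(rdd)
  val lenientDf = spark.read
    .option("allowUnquotedFieldNames", "true")
    .json(rdd)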
ignore("allowNonNumericNumbers off") { val str = """{"age": NaN}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } ignore("allowNonNumericNumbers on") { val str = """{"age": NaN}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowNonNumericNumbers", "true").json(rdd) assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) @@ -114,16 +114,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowBackslashEscapingAnyCharacter off") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowBackslashEscapingAnyCharacter on") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.schema.last.name == "price") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b1279abd63..63fe4658d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -229,7 +229,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = sqlContext.read.json(jsonNullStruct) + val jsonDF = spark.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -248,7 +248,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val jsonDF = spark.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -276,7 +276,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = sqlContext.read.json(complexFieldAndType1) + val jsonDF = spark.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -375,7 +375,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = sqlContext.read.json(complexFieldAndType1) + val jsonDF = spark.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -391,7 +391,7 @@ class JsonSuite 
extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = spark.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -463,7 +463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = spark.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -516,7 +516,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) + val jsonDF = spark.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -540,7 +540,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = sqlContext.read.json(arrayElementTypeConflict) + val jsonDF = spark.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -568,7 +568,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = sqlContext.read.json(missingFields) + val jsonDF = spark.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -588,7 +588,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -620,7 +620,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path) + val jsonDF = spark.read.option("primitivesAsString", "true").json(path) val expectedSchema = StructType( StructField("bigInteger", StringType, true) :: @@ -648,7 +648,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Loading a JSON dataset primitivesAsString returns complex fields as strings") { - val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1) + val jsonDF = spark.read.option("primitivesAsString", "true").json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -746,7 +746,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") { - val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType) + val jsonDF = spark.read.option("prefersDecimal", "true").json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ 
-777,7 +777,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val mixedIntegerAndDoubleRecords = sparkContext.parallelize( """{"a": 3, "b": 1.1}""" :: s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) @@ -796,7 +796,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Infer big integers correctly even when it does not fit in decimal") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .json(bigIntegerRecords) // The value in `a` field will be a double as it does not fit in decimal. For `b` field, @@ -810,7 +810,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Infer floating-point values correctly even when it does not fit in decimal") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("prefersDecimal", "true") .json(floatingValueRecords) @@ -823,7 +823,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01))) - val mergedJsonDF = sqlContext.read + val mergedJsonDF = spark.read .option("prefersDecimal", "true") .json(floatingValueRecords ++ bigIntegerRecords) @@ -881,7 +881,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = sqlContext.read.schema(schema).json(path) + val jsonDF1 = spark.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -898,7 +898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = spark.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -919,7 +919,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = spark.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -947,7 +947,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = spark.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -973,7 +973,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = sqlContext.read.json(complexFieldAndType2) + val jsonDF = spark.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -991,7 +991,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = sqlContext.read.json(complexFieldAndType2) + val jsonDF = spark.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -1014,7 +1014,7 @@ class JsonSuite extends QueryTest with 
SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = sqlContext.read.json(jsonArray) + val jsonDF = spark.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -1035,7 +1035,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("a", StringType, true) :: Nil) // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { - sqlContext.read + spark.read .option("mode", "FAILFAST") .json(corruptRecords) .collect() @@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) val exceptionTwo = intercept[SparkException] { - sqlContext.read + spark.read .option("mode", "FAILFAST") .schema(schema) .json(corruptRecords) @@ -1060,7 +1060,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaTwo = StructType( StructField("a", StringType, true) :: Nil) // `DROPMALFORMED` mode should skip corrupt records - val jsonDFOne = sqlContext.read + val jsonDFOne = spark.read .option("mode", "DROPMALFORMED") .json(corruptRecords) checkAnswer( @@ -1069,7 +1069,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) assert(jsonDFOne.schema === schemaOne) - val jsonDFTwo = sqlContext.read + val jsonDFTwo = spark.read .option("mode", "DROPMALFORMED") .schema(schemaTwo) .json(corruptRecords) @@ -1083,7 +1083,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Test if we can query corrupt records. withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { withTempTable("jsonTable") { - val jsonDF = sqlContext.read.json(corruptRecords) + val jsonDF = spark.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("_unparsed", StringType, true) :: @@ -1134,7 +1134,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-13953 Rename the corrupt record field via option") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("columnNameOfCorruptRecord", "_malformed") .json(corruptRecords) val schema = StructType( @@ -1155,7 +1155,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-4068: nulls in arrays") { - val jsonDF = sqlContext.read.json(nullsInArrays) + val jsonDF = spark.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -1201,7 +1201,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df1 = spark.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -1224,7 +1224,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD2, schema2) + val df3 = spark.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -1232,8 +1232,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === 
"{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = sqlContext.read.json(primitiveFieldAndType) - val primTable = sqlContext.read.json(jsonDF.toJSON.rdd) + val jsonDF = spark.read.json(primitiveFieldAndType) + val primTable = spark.read.json(jsonDF.toJSON.rdd) primTable.registerTempTable("primitiveTable") checkAnswer( sql("select * from primitiveTable"), @@ -1245,8 +1245,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = sqlContext.read.json(complexFieldAndType1) - val compTable = sqlContext.read.json(complexJsonDF.toJSON.rdd) + val complexJsonDF = spark.read.json(complexFieldAndType1) + val compTable = spark.read.json(complexJsonDF.toJSON.rdd) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1316,7 +1316,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = DataSource( - sqlContext.sparkSession, + spark, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, @@ -1324,7 +1324,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { options = Map("path" -> path)).resolveRelation() val d2 = DataSource( - sqlContext.sparkSession, + spark, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, @@ -1345,16 +1345,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempDir { dir => val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val df = spark.read.schema(schemaWithSimpleMap).json(mapType1) val path = dir.getAbsolutePath df.write.mode("overwrite").parquet(path) // order of MapType is not defined - assert(sqlContext.read.parquet(path).count() == 5) + assert(spark.read.parquet(path).count() == 5) - val df2 = sqlContext.read.json(corruptRecords) + val df2 = spark.read.json(corruptRecords) df2.write.mode("overwrite").parquet(path) - checkAnswer(sqlContext.read.parquet(path), df2.collect()) + checkAnswer(spark.read.parquet(path), df2.collect()) } } } @@ -1387,7 +1387,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "col1", "abd") - sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + spark.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) checkAnswer(sql( @@ -1447,7 +1447,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(4.75.toFloat, Seq(false, true)), new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))) val data = - Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil // Data generated by previous versions. // scalastyle:off @@ -1462,7 +1462,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // scalastyle:on // Generate data for the current version. 
- val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) withTempPath { path => df.write.format("json").mode("overwrite").save(path.getCanonicalPath) @@ -1486,13 +1486,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "Spark 1.4.1", "Spark 1.5.0", "Spark 1.5.0", - "Spark " + sqlContext.sparkContext.version, - "Spark " + sqlContext.sparkContext.version) + "Spark " + spark.sparkContext.version, + "Spark " + spark.sparkContext.version) val expectedResult = col0Values.map { v => Row.fromSeq(Seq(v) ++ constantValues) } checkAnswer( - sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + spark.read.format("json").schema(schema).load(path.getCanonicalPath), expectedResult ) } @@ -1502,16 +1502,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(2) + val df = spark.range(2) df.write.json(path + "/p=1") df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) + assert(spark.read.json(path).count() === 4) val extraOptions = Map( "mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName, "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName ) - assert(sqlContext.read.options(extraOptions).json(path).count() === 2) + assert(spark.read.options(extraOptions).json(path).count() === 2) } } @@ -1525,12 +1525,12 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { { // We need to make sure we can infer the schema. - val jsonDF = sqlContext.read.json(additionalCorruptRecords) + val jsonDF = spark.read.json(additionalCorruptRecords) assert(jsonDF.schema === schema) } { - val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) + val jsonDF = spark.read.schema(schema).json(additionalCorruptRecords) jsonDF.registerTempTable("jsonTable") // In HiveContext, backticks should be used to access columns starting with a underscore. 
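The SPARK-12057 hunks above read the corrupt records twice, once inferring the schema and once supplying it explicitly. A sketch of the second path, assuming spark and an RDD of malformed JSON strings called badRecords:

  import org.apache.spark.sql.types._

  val schema = StructType(
    StructField("_corrupt_record", StringType, nullable = true) ::
    StructField("dummy", StringType, nullable = true) :: Nil)

  // Records that fail to parse land in _corrupt_record; the rest populate "dummy".
  val jsonDF = spark.read.schema(schema).json(badRecords)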
@@ -1563,7 +1563,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("a", StructType( StructField("b", StringType) :: Nil )) :: Nil) - val jsonDF = sqlContext.read.schema(schema).json(path) + val jsonDF = spark.read.schema(schema).json(path) assert(jsonDF.count() == 2) } } @@ -1575,7 +1575,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") @@ -1585,7 +1585,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val compressedFiles = new File(jsonDir).listFiles() assert(compressedFiles.exists(_.getName.endsWith(".json.gz"))) - val jsonCopy = sqlContext.read + val jsonCopy = spark.read .format("json") .load(jsonDir) @@ -1611,7 +1611,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") @@ -1622,7 +1622,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val compressedFiles = new File(jsonDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".json.gz"))) - val jsonCopy = sqlContext.read + val jsonCopy = spark.read .format("json") .options(extraOptions) .load(jsonDir) @@ -1637,7 +1637,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Casting long as timestamp") { withTempTable("jsonTable") { val schema = (new StructType).add("ts", TimestampType) - val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong) + val jsonDF = spark.read.schema(schema).json(timestampAsLong) jsonDF.registerTempTable("jsonTable") @@ -1657,8 +1657,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val json = s""" |{"a": [{$nested}], "b": [{$nested}]} """.stripMargin - val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.makeRDD(Seq(json)) + val df = spark.read.json(rdd) assert(df.schema.size === 2) df.collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 2873c6a881..f4a3336643 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession private[json] trait TestJsonData { - protected def sqlContext: SQLContext + protected def spark: SparkSession def primitiveFieldAndType: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,7 +35,7 @@ private[json] trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - 
sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,14 +46,14 @@ private[json] trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,14 +64,14 @@ private[json] trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -79,7 +79,7 @@ private[json] trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,7 +95,7 @@ private[json] trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,7 +149,7 @@ private[json] trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -157,7 +157,7 @@ private[json] trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,21 +166,21 @@ private[json] trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: 
Nil) def corruptRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,7 +189,7 @@ private[json] trait TestJsonData { """]""" :: Nil) def additionalCorruptRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"dummy":"test"}""" :: """[1,2,3]""" :: """":"test", "a":1}""" :: @@ -197,7 +197,7 @@ private[json] trait TestJsonData { """ ","ian":"test"}""" :: Nil) def emptyRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -206,23 +206,23 @@ private[json] trait TestJsonData { """]""" :: Nil) def timestampAsLong: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"ts":1451732645}""" :: Nil) def arrayAndStructRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"a": {"b": 1}}""" :: """{"a": []}""" :: Nil) def floatingValueRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil) def bigIntegerRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) - lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil) - def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index f98ea8c5ae..6509e04e85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -67,7 +67,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( i % 2 == 0, i, @@ -114,7 +114,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => if (i % 3 == 0) { Row.apply(Seq.fill(7)(null): _*) } else { @@ -155,7 +155,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( Seq.tabulate(3)(i => s"val_$i"), if (i % 3 == 0) null else Seq.tabulate(3)(identity)) @@ -182,7 +182,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) }) } @@ -205,7 +205,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared 
logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) }) } @@ -221,7 +221,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( Seq.tabulate(3)(n => s"arr_${i + n}"), Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, @@ -267,7 +267,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared } } - checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + checkAnswer(spark.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 45cc6810d4..57cd70e191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -38,7 +38,7 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq } protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() val fsPath = new Path(path) val fs = fsPath.getFileSystem(hadoopConf) val parquetFiles = fs.listStatus(fsPath, new PathFilter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 65635e3c06..45fd6a5d80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -304,7 +304,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"), + spark.read.parquet(dir.getCanonicalPath).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -321,7 +321,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), + spark.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), (2 to 3).map(i => Row(i, i.toString, 1))) } } @@ -339,7 +339,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty, // this query will throw an exception since the project from combinedFilter expect // two projection while the - val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + val df1 = 
spark.read.parquet(dir.getCanonicalPath) assert(df1.filter("a > 1 or b < 2").count() == 2) } @@ -358,7 +358,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // test the generate new projection case // when projects != partitionAndNormalColumnProjs - val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + val df1 = spark.read.parquet(dir.getCanonicalPath) checkAnswer( df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"), @@ -381,7 +381,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). - val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") + val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( df, Row(1, "1", null)) @@ -394,7 +394,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex df.write.parquet(pathThree) // We will remove the temporary metadata when writing Parquet file. - val schema = sqlContext.read.parquet(pathThree).schema + val schema = spark.read.parquet(pathThree).schema assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) val pathFour = s"${dir.getCanonicalPath}/table4" @@ -407,7 +407,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "s.c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. - val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1") .selectExpr("s") checkAnswer(dfStruct3, Row(Row(null, 1))) @@ -420,7 +420,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex dfStruct3.write.parquet(pathSix) // We will remove the temporary metadata when writing Parquet file. - val forPathSix = sqlContext.read.parquet(pathSix).schema + val forPathSix = spark.read.parquet(pathSix).schema assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) // sanity test: make sure optional metadata field is not wrongly set. @@ -429,7 +429,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val pathEight = s"${dir.getCanonicalPath}/table8" (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) - val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") checkAnswer( df2, Row(1, "1")) @@ -449,7 +449,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") + val df = spark.read.parquet(path).filter("a = 2") // The result should be single row. // When a filter is pushed to Parquet, Parquet can apply it to every row. 
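The ParquetFilterSuite assertions above count rows after Spark's own filter is stripped, so only Parquet-level filtering remains. A rough sketch of the read side, assuming spark and a hypothetical directory parquetPath containing columns a and b:

  // If "a = 2" is pushed down, the scan itself returns only matching rows,
  // which is what the stripped-filter row counts in the tests verify.
  val df = spark.read.parquet(parquetPath).filter("a = 2")
  df.explain()   // the pushed predicate is visible in the scan node of the plan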
@@ -470,11 +470,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path) checkAnswer( - sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), + spark.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) checkAnswer( - sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"), + spark.read.parquet(path).where("not (a = 2 and b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) } } @@ -527,19 +527,19 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // When a filter is pushed to Parquet, Parquet can apply it to every row. // So, we can check the number of rows returned from the Parquet // to make sure our filter pushdown work. - val df = sqlContext.read.parquet(path).where("b in (0,2)") + val df = spark.read.parquet(path).where("b in (0,2)") assert(stripSparkFilter(df).count == 3) - val df1 = sqlContext.read.parquet(path).where("not (b in (1))") + val df1 = spark.read.parquet(path).where("not (b in (1))") assert(stripSparkFilter(df1).count == 3) - val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)") + val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)") assert(stripSparkFilter(df2).count == 2) - val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)") + val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)") assert(stripSparkFilter(df3).count == 4) - val df4 = sqlContext.read.parquet(path).where("not (a <= 2)") + val df4 = spark.read.parquet(path).where("not (a <= 2)") assert(stripSparkFilter(df4).count == 3) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 32fe5ba127..d0107aae5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -113,7 +113,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) readParquetFile(path.toString)(df => { val sparkTypes = df.schema.map(_.dataType) @@ -132,7 +132,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { testStandardAndLegacyModes("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = { - sqlContext + spark .range(1000) // Parquet doesn't allow column names with spaces, have to add an alias here. // Minus 500 here so that negative decimals are also tested. 
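makeDecimalRDD above generates its test decimals from spark.range. A compact sketch of the idea, assuming spark; the column alias and precision are illustrative only:

  // Shift by -500 so roughly half of the generated decimals are negative,
  // then cast to a fixed-precision decimal for the Parquet round trip.
  val decimals = spark.range(1000)
    .selectExpr("cast((id - 500) / 100.0 as decimal(10, 2)) as dec")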
@@ -250,10 +250,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) val errorMessage = intercept[Throwable] { - sqlContext.read.parquet(path.toString).printSchema() + spark.read.parquet(path.toString).printSchema() }.toString assert(errorMessage.contains("Parquet type not supported")) } @@ -271,15 +271,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) - val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + val sparkTypes = spark.read.parquet(path.toString).schema.map(_.dataType) assert(sparkTypes === expectedSparkTypes) } } test("compression codec") { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() def compressionCodecFor(path: String, codecName: String): String = { val codecs = for { footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) @@ -296,7 +296,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { compressionCodecFor(path, codec.name()) } } @@ -304,7 +304,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) + checkCompressionCodec( + CompressionCodecName.fromConf(spark.conf.get(SQLConf.PARQUET_COMPRESSION))) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -351,7 +352,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("write metadata") { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() withTempPath { file => val path = new Path(file.toURI.toString) val fs = FileSystem.getLocal(hadoopConf) @@ -433,7 +434,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf, extraMetadata) readParquetFile(path.toString) { df => @@ -455,7 +456,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { ) withTempPath { dir => val message = intercept[SparkException] { - sqlContext.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath) + spark.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(message === "Intentional exception for testing purposes") } @@ -465,10 +466,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with 
SharedSQLContext { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// intercept[Throwable] { - sqlContext.read.parquet("file:///nonexistent") + spark.read.parquet("file:///nonexistent") } val errorMessage = intercept[Throwable] { - sqlContext.read.parquet("hdfs://nonexistent") + spark.read.parquet("hdfs://nonexistent") }.toString assert(errorMessage.contains("UnknownHostException")) } @@ -486,14 +487,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val m1 = intercept[SparkException] { - sqlContext.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) + spark.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m1.contains("Intentional exception for testing purposes")) } withTempPath { dir => val m2 = intercept[SparkException] { - val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + val df = spark.range(1).select('id as 'a, 'id as 'b).coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m2.contains("Intentional exception for testing purposes")) @@ -512,11 +513,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { ParquetOutputFormat.ENABLE_DICTIONARY -> "true" ) - val hadoopConf = sqlContext.sessionState.newHadoopConfWithOptions(extraOptions) + val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" - sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") + spark.range(1 << 16).selectExpr("(id % 4) AS i") .coalesce(1).write.options(extraOptions).mode("overwrite").parquet(path) val blockMetadata = readFooter(new Path(path), hadoopConf).getBlocks.asScala.head @@ -531,7 +532,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("null and non-null strings") { // Create a dataset where the first values are NULL and then some non-null values. The // number of non-nulls needs to be bigger than the ParquetReader batch size. 
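Several of the remaining ParquetIOSuite hunks below flip the vectorized Parquet reader on and off around a read (the tests use the withSQLConf helper; plain conf.set works outside a test). A minimal sketch, assuming spark and a hypothetical somePath:

  spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
  val rowBasedRead = spark.read.parquet(somePath).collect()
  spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
  val vectorizedRead = spark.read.parquet(somePath).collect()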
- val data: Dataset[String] = sqlContext.range(200).map (i => + val data: Dataset[String] = spark.range(200).map (i => if (i < 150) null else "a" ) @@ -554,7 +555,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("dec-in-i32.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) } } } @@ -565,7 +566,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("dec-in-i64.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) } } } @@ -576,7 +577,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("dec-in-fixed-len.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } } } @@ -589,7 +590,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { var hash2: Int = 0 (false :: true :: Nil).foreach { v => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> v.toString) { - val df = sqlContext.read.parquet(dir.getCanonicalPath) + val df = spark.read.parquet(dir.getCanonicalPath) val rows = df.queryExecution.toRdd.map(_.copy()).collect() val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) if (!v) { @@ -607,7 +608,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("VectorizedParquetRecordReader - direct path read") { val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString)) withTempPath { dir => - sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) + spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { val reader = new VectorizedParquetRecordReader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 83b65fb419..9dc56292c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -81,7 +81,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS logParquetSchema(protobufStylePath) checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath), + spark.read.parquet(dir.getCanonicalPath), Seq( Row(Seq(0, 1)), Row(Seq(2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index b4d35be05d..8707e13461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -400,7 +400,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // Introduce _temporary dir to the base dir the robustness of the schema discovery process. new File(base.getCanonicalPath, "_temporary").mkdir() - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + spark.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -484,7 +484,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + spark.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -532,7 +532,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -572,7 +572,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -604,7 +604,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - sqlContext + spark .read .option("mergeSchema", "true") .format("parquet") @@ -622,7 +622,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("SPARK-7749 Non-partitioned table should have empty partition spec") { withTempPath { dir => (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) - val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution + val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation(relation: HadoopFsRelation, _, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) @@ -636,7 +636,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempPath { dir => val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + checkAnswer(spark.read.parquet(dir.getCanonicalPath), df.collect()) } } @@ -676,12 +676,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - 
checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } } @@ -697,7 +697,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(dir.getCanonicalPath), df) } } @@ -714,7 +714,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } withTempPath { dir => @@ -731,7 +731,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } } @@ -746,7 +746,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha .save(tablePath.getCanonicalPath) val twoPartitionsDF = - sqlContext + spark .read .option("basePath", tablePath.getCanonicalPath) .parquet( @@ -756,7 +756,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha checkAnswer(twoPartitionsDF, df.filter("b != 3")) intercept[AssertionError] { - sqlContext + spark .read .parquet( s"${tablePath.getCanonicalPath}/b=1", @@ -829,7 +829,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1", "_SUCCESS")) Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1", "_SUCCESS")) Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1/d=1", "_SUCCESS")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } } } @@ -884,9 +884,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempPath { dir => withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { val path = dir.getCanonicalPath - val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + val df = spark.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) df.write.partitionBy("b", "c").parquet(path) - checkAnswer(sqlContext.read.parquet(path), df) + checkAnswer(spark.read.parquet(path), df) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index f1e9726c38..f9f9f80352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -46,24 +46,24 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => 
(i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + spark.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") // Query appends, don't test with both read modes. withParquetTable(data, "t", false) { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + spark.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(spark.table("t"), data.map(Row.fromTuple)) } - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tmp"), ignoreIfNotExists = true) } @@ -128,9 +128,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = spark.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -139,12 +139,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -163,9 +163,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -181,19 +181,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + spark.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + spark.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -204,10 +204,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val basePath = dir.getCanonicalPath val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.write.parquet(basePath) - val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + val decimal = spark.read.parquet(basePath).first().getDecimal(0) assert(Decimal("67123.45") === Decimal(decimal)) } } @@ -227,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { checkAnswer( - sqlContext.read.option("mergeSchema", "true").parquet(path), + spark.read.option("mergeSchema", "true").parquet(path), Seq( Row(Row(1, 1, null)), Row(Row(2, 2, null)), @@ -240,7 +240,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - same schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS 
s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -253,7 +253,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L))) } } @@ -261,12 +261,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-11997 parquet with null partition values") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(1, 3) + spark.range(1, 3) .selectExpr("if(id % 2 = 0, null, id) AS n", "id") .write.partitionBy("n").parquet(path) checkAnswer( - sqlContext.read.parquet(path).filter("n is null"), + spark.read.parquet(path).filter("n is null"), Row(2, null)) } } @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -288,7 +288,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(null, null))) } } @@ -296,7 +296,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -311,13 +311,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L, null, null))) } withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -332,7 +332,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, null, null, 3L))) } } @@ -340,7 +340,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) @@ -357,13 +357,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L))) } 
withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) @@ -380,7 +380,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 3L))) } } @@ -388,7 +388,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) @@ -406,7 +406,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(1L, 2L, null))) } } @@ -415,7 +415,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") .coalesce(1) @@ -436,7 +436,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(Seq(Row(0, null))))) } } @@ -445,12 +445,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df1 = sqlContext + val df1 = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) - val df2 = sqlContext + val df2 = spark .range(1, 2) .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") .coalesce(1) @@ -467,7 +467,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Seq( Row(Row(0, 1, null)), Row(Row(null, 2, 4)))) @@ -478,12 +478,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df1 = sqlContext + val df1 = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") .coalesce(1) - val df2 = sqlContext + val df2 = spark .range(1, 2) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) @@ -492,7 +492,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext df2.write.mode(SaveMode.Append).parquet(path) checkAnswer( - sqlContext + spark .read .option("mergeSchema", "true") .parquet(path) @@ -507,7 +507,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr( """NAMED_STRUCT( @@ -532,7 +532,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(NestedStruct(1, 
2L, 3.5D)))) } } @@ -585,9 +585,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df) + checkAnswer(spark.read.parquet(path), df) } } @@ -595,11 +595,11 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withSQLConf("spark.sql.codegen.maxFields" -> "100") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // donot return batch, because whole stage codegen is disabled for wide table (>200 columns) - val df2 = sqlContext.read.parquet(path) + val df2 = spark.read.parquet(path) assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty, "Should not return batch") checkAnswer(df2, df) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cef541f044..373d3a3a0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -21,9 +21,9 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Try -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.{Benchmark, Utils} /** @@ -34,12 +34,16 @@ import org.apache.spark.util.{Benchmark, Utils} object ParquetReadBenchmark { val conf = new SparkConf() conf.set("spark.sql.parquet.compression.codec", "snappy") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + + val spark = SparkSession.builder + .master("local[1]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() // Set default configs. Individual cases will change them if necessary. 
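In ParquetReadBenchmark the fixed SparkContext/SQLContext pair is replaced by a single SparkSession built from the SparkConf, and the default options referred to in the comment above are then set through spark.conf instead of sqlContext.conf.setConfString (the setting lines follow below). A standalone sketch of that setup pattern, with a hypothetical app name:

  import org.apache.spark.SparkConf
  import org.apache.spark.sql.SparkSession

  object BenchmarkSessionSketch {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf()
        .set("spark.sql.parquet.compression.codec", "snappy")
      val spark = SparkSession.builder
        .master("local[1]")
        .appName("parquet-read-benchmark-sketch")  // hypothetical name
        .config(conf)
        .getOrCreate()

      // Runtime SQL options now go through the session's RuntimeConfig ...
      spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
      // ... and can be read back or unset through the same handle.
      println(spark.conf.get("spark.sql.parquet.enableVectorizedReader"))
      spark.conf.unset("spark.sql.parquet.enableVectorizedReader")
      spark.stop()
    }
  }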
- sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() @@ -48,17 +52,17 @@ object ParquetReadBenchmark { } def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(spark.catalog.dropTempTable) } def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -71,18 +75,18 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as id from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } sqlBenchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } } @@ -155,20 +159,20 @@ object ParquetReadBenchmark { def intStringScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("Int and String Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } } @@ -189,20 +193,20 @@ object ParquetReadBenchmark { def stringDictionaryScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - 
sqlContext.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("String Dictionary", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } } @@ -221,23 +225,23 @@ object ParquetReadBenchmark { def partitionTableScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select id % 2 as p, cast(id as INT) as id from t1") .write.partitionBy("p").parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("Partitioned Table", values) benchmark.addCase("Read data column") { iter => - sqlContext.sql("select sum(id) from tempTable").collect + spark.sql("select sum(id) from tempTable").collect } benchmark.addCase("Read partition column") { iter => - sqlContext.sql("select sum(p) from tempTable").collect + spark.sql("select sum(p) from tempTable").collect } benchmark.addCase("Read both columns") { iter => - sqlContext.sql("select sum(p), sum(id) from tempTable").collect + spark.sql("select sum(p), sum(id) from tempTable").collect } /* @@ -256,16 +260,16 @@ object ParquetReadBenchmark { def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + + spark.range(values).registerTempTable("t1") + spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("String with Nulls Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " + + spark.sql("select sum(length(c2)) from tempTable where c1 is " + "not NULL and c2 is not NULL").collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 90e3d50714..c43b142de2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -453,11 
+453,11 @@ class ParquetSchemaSuite extends ParquetSchemaTest { test("schema merging failure error message") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage assert(message.contains("Failed merging schema of file")) @@ -466,13 +466,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // test for second merging (after read Parquet schema in parallel done) withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") - sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + spark.sparkContext.conf.set("spark.default.parallelism", "20") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage assert(message.contains("Failed merging schema:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index e8c524e9e5..b5fc51603e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -52,7 +52,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (true :: false :: Nil).foreach { vectorized => if (!vectorized || testVectorized) { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - f(sqlContext.read.parquet(path.toString)) + f(spark.read.parquet(path.toString)) } } } @@ -66,7 +66,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + spark.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -90,14 +90,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { withParquetDataFrame(data, testVectorized) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + spark.wrapped.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + spark.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( @@ -173,6 +173,6 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def readResourceParquetFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) - sqlContext.read.parquet(url.toString) + 
spark.read.parquet(url.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 88a3d878f9..ff5706999a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -32,7 +32,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar |${readParquetSchema(parquetFilePath.toString)} """.stripMargin) - checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + checkAnswer(spark.read.parquet(parquetFilePath.toString), (0 until 10).map { i => val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") val nonNullablePrimitiveValues = Seq( @@ -139,7 +139,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar logParquetSchema(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(Seq(Seq(0, 1), Seq(2, 3))), Row(Seq(Seq(4, 5), Seq(6, 7))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala index fd56265297..08b7eb3cf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.util.Benchmark @@ -36,9 +36,10 @@ object TPCDSBenchmark { conf.set("spark.driver.memory", "3g") conf.set("spark.executor.memory", "3g") conf.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) + conf.setMaster("local[1]") + conf.setAppName("test-sql-context") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession.builder.config(conf).getOrCreate() // These queries a subset of the TPCDS benchmark queries and are taken from // https://github.com/databricks/spark-sql-perf/blob/master/src/main/scala/com/databricks/spark/ @@ -1186,8 +1187,8 @@ object TPCDSBenchmark { def setupTables(dataLocation: String): Map[String, Long] = { tables.map { tableName => - sqlContext.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) - tableName -> sqlContext.table(tableName).count() + spark.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) + tableName -> spark.table(tableName).count() }.toMap } @@ -1195,18 +1196,18 @@ object TPCDSBenchmark { require(dataLocation.nonEmpty, "please modify the value of dataLocation to point to your local TPCDS data") val tableSizes = setupTables(dataLocation) - sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + 
spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") tpcds.filter(q => q._1 != "").foreach { case (name: String, query: String) => - val numRows = sqlContext.sql(query).queryExecution.logical.map { + val numRows = spark.sql(query).queryExecution.logical.map { case ur@UnresolvedRelation(t: TableIdentifier, _) => tableSizes.getOrElse(t.table, throw new RuntimeException(s"${t.table} not found.")) case _ => 0L }.sum val benchmark = new Benchmark("TPCDS Snappy (scale = 5)", numRows, 5) benchmark.addCase(name) { i => - sqlContext.sql(query).collect() + spark.sql(query).collect() } benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 923c0b350e..f61fce5d41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -33,20 +33,20 @@ import org.apache.spark.util.Utils class TextSuite extends QueryTest with SharedSQLContext { test("reading text file") { - verifyFrame(sqlContext.read.format("text").load(testFile)) + verifyFrame(spark.read.format("text").load(testFile)) } test("SQLContext.read.text() API") { - verifyFrame(sqlContext.read.text(testFile).toDF()) + verifyFrame(spark.read.text(testFile).toDF()) } test("SPARK-12562 verify write.text() can handle column name beyond `value`") { - val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") + val df = spark.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() df.write.text(tempFile.getCanonicalPath) - verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath).toDF()) + verifyFrame(spark.read.text(tempFile.getCanonicalPath).toDF()) Utils.deleteRecursively(tempFile) } @@ -55,18 +55,18 @@ class TextSuite extends QueryTest with SharedSQLContext { val tempFile = Utils.createTempDir() tempFile.delete() - val df = sqlContext.range(2) + val df = spark.range(2) intercept[AnalysisException] { df.write.text(tempFile.getCanonicalPath) } intercept[AnalysisException] { - sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + spark.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) } } test("SPARK-13503 Support to specify the option for compression codec for TEXT") { - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz") extensionNameMap.foreach { case (codecName, extension) => @@ -75,7 +75,7 @@ class TextSuite extends QueryTest with SharedSQLContext { testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(_.getName.endsWith(s".txt$extension"))) - verifyFrame(sqlContext.read.text(tempDirPath).toDF()) + verifyFrame(spark.read.text(tempDirPath).toDF()) } val errMsg = intercept[IllegalArgumentException] { @@ -95,14 +95,14 @@ class TextSuite extends QueryTest with SharedSQLContext { "mapreduce.map.output.compress.codec" -> classOf[GzipCodec].getName ) withTempDir { dir => - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val tempDir = Utils.createTempDir() val tempDirPath = 
tempDir.getAbsolutePath testDf.write.option("compression", "none") .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) - verifyFrame(sqlContext.read.options(extraOptions).text(tempDirPath).toDF()) + verifyFrame(spark.read.options(extraOptions).text(tempDirPath).toDF()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8aa0114d98..4fc52c99fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -33,7 +33,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index b9df43d049..730ec43556 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.{QueryTest, SparkSession} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ @@ -34,7 +34,7 @@ import org.apache.spark.sql.functions._ * without serializing the hashed relation, which does not happen in local mode. */ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - protected var sqlContext: SQLContext = null + protected var spark: SparkSession = null /** * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. @@ -45,26 +45,26 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { .setMaster("local-cluster[2,1,1024]") .setAppName("testing") val sc = new SparkContext(conf) - sqlContext = new SQLContext(sc) + spark = SparkSession.builder.getOrCreate() } override def afterAll(): Unit = { - sqlContext.sparkContext.stop() - sqlContext = null + spark.stop() + spark = null } /** * Test whether the specified broadcast join updates the peak execution memory accumulator. 
*/ private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { - val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) val plan = - EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) + EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 2a4a3690f2..7caeb3be54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -32,7 +32,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.newProductEncoder import testImplicits.localSeqToDatasetHolder - private lazy val myUpperCaseData = sqlContext.createDataFrame( + private lazy val myUpperCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), @@ -43,7 +43,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = sqlContext.createDataFrame( + private lazy val myLowerCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), @@ -99,7 +99,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) + EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin) } def makeShuffledHashJoin( @@ -113,7 +113,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) val filteredJoin = boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) + EnsureRequirements(spark.sessionState.conf).apply(filteredJoin) } def makeSortMergeJoin( @@ -124,7 +124,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) + EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c26cb8483e..001feb0f2b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = sqlContext.createDataFrame( + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), @@ -42,7 +42,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = sqlContext.createDataFrame( + private lazy val right = spark.createDataFrame( sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -82,7 +82,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( ShuffledHashJoinExec( leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), @@ -115,7 +115,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d41e88a0aa..1b82769428 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -71,21 +71,21 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.listener.executionIdToData.keySet withSQLConf("spark.sql.codegen.wholeStage" -> "false") { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. 
assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) @@ -128,7 +128,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) // TODO: update metrics in generated operators - val ds = sqlContext.range(10).filter('id < 5) + val ds = spark.range(10).filter('id < 5) testSparkPlanMetrics(ds.toDF(), 1, Map.empty) } @@ -157,7 +157,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) - val ds = sqlContext.range(10).sort('id) + val ds = spark.range(10).sort('id) testSparkPlanMetrics(ds.toDF(), 2, Map.empty) } @@ -169,7 +169,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -187,7 +187,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -195,7 +195,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "number of output rows" -> 8L))) ) - val df2 = sqlContext.sql( + val df2 = spark.sql( "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -241,7 +241,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") testSparkPlanMetrics(df, 3, Map( @@ -269,7 +269,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... 
-> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin") testSparkPlanMetrics(df, 1, Map( 0L -> ("CartesianProduct", Map( @@ -280,19 +280,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq.exists(_ === "2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 7b413dda1e..a7b2cfe7d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -217,7 +217,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3", SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") { withFileStreamSinkLog { sinkLog => - val fs = sinkLog.metadataPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) def listBatchFiles(): Set[String] = { fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName => @@ -263,7 +263,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { withTempDir { file => - val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath) + val sinkLog = new FileStreamSinkLog(spark, file.getCanonicalPath) f(sinkLog) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 5f92c5bb9b..ef2b479a56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -59,63 +59,63 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log 
make the dir - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) // Adding the same batch does nothing metadataLog.add(1, "batch1-duplicated") assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - sqlContext.conf.setConfString( + spark.conf.set( s"fs.$scheme.impl", classOf[FakeFileSystem].getName) withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) } } test("HDFSMetadataLog: restart") { withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog2 = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.get(1) === Some("batch1")) assert(metadataLog2.getLatest() === Some(1 -> "batch1")) - assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog2.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } @@ -127,7 +127,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { new Thread() { override def run(): Unit = waiter { val metadataLog = - new 
HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + new HDFSMetadataLog[String](spark, temp.getAbsolutePath) try { var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) nextBatchId += 1 @@ -146,9 +146,10 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } waiter.await(timeout(10.seconds), dismissals(10)) - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString)) - assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) + assert( + metadataLog.get(None, Some(maxBatchId)) === (0 to maxBatchId).map(i => (i, i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 6be94eb24f..4fa1754253 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -27,10 +27,11 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -54,19 +55,18 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - withSpark(new SparkContext(sparkConf)) { sc => - val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = - makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
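The StateStoreRDDSuite hunks above and below swap withSpark(new SparkContext(...)) for a withSparkSession(...) loan helper and then reach the SparkContext (and, where an API still expects one, the legacy SQLContext via spark.wrapped) through the session. A minimal sketch of that loan pattern, written out here rather than taken from the patch's LocalSparkSession helper:

  import org.apache.spark.SparkConf
  import org.apache.spark.sql.SparkSession

  object WithSparkSessionSketch {
    // Loan pattern: hand the session to the body and always stop it afterwards,
    // mirroring what the test's withSparkSession helper is used for.
    def withSparkSession[T](session: SparkSession)(body: SparkSession => T): T =
      try body(session) finally session.stop()

    def main(args: Array[String]): Unit = {
      val sparkConf = new SparkConf().setMaster("local[2]").setAppName("state-store-sketch")
      withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
        // The SparkContext the old tests created directly is now reached via the session.
        val counts = spark.sparkContext.parallelize(Seq("a", "b", "a")).countByValue()
        println(counts)  // Map(a -> 2, b -> 1)
      }
    }
  }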
@@ -79,30 +79,30 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( - sc: SparkContext, + spark: SparkSession, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + implicit val sqlContext = spark.wrapped + makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) } // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + require(makeStoreRDD(spark, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + assert(makeStoreRDD(spark, Seq("a"), 20).collect().toSet === Set("a" -> 21)) } } test("usage with iterators - only gets and only puts") { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 @@ -130,15 +130,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } } - val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) - val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) - val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -149,8 +149,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val coordinatorRef = sqlContext.streams.stateStoreCoordinator coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") @@ -159,7 +159,7 @@ class StateStoreRDDSuite 
extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) @@ -178,16 +178,20 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { - withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContext = new SQLContext(sc) + + withSparkSession( + SparkSession.builder + .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) + .getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 67e44849ca..9eff42ab2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -98,7 +98,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } } - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -239,7 +239,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -269,7 +269,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -310,7 +310,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -340,16 +340,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("SPARK-11126: no memory 
leak when running non SQL jobs") { - val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size - sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val previousStageNumber = spark.listener.stageIdToStageMetrics.size + spark.sparkContext.parallelize(1 to 10).foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber) - sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } test("SPARK-13055: history listener only tracks SQL metrics") { @@ -401,8 +401,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { val sc = new SparkContext(conf) try { SQLContext.clearSqlListener() - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = new SQLContext(sc) + import spark.implicits._ // Run 100 successful executions and 100 failed executions. // Each execution only has one job and one stage. for (i <- 0 until 100) { @@ -418,12 +418,12 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) + assert(spark.listener.getCompletedExecutions.size <= 50) + assert(spark.listener.getFailedExecutions.size <= 50) // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + assert(spark.listener.executionIdToData.size <= 100) + assert(spark.listener.jobIdToExecutionId.size <= 100) + assert(spark.listener.stageIdToStageMetrics.size <= 100) } finally { sc.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 73c2076a30..56f848b9a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -37,13 +37,12 @@ class CatalogSuite with BeforeAndAfterEach with SharedSQLContext { - private def sparkSession: SparkSession = sqlContext.sparkSession - private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog + private def sessionCatalog: SessionCatalog = spark.sessionState.catalog private val utils = new CatalogTestUtils { override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override def newEmptyCatalog(): ExternalCatalog = sparkSession.sharedState.externalCatalog + override def newEmptyCatalog(): ExternalCatalog = spark.sharedState.externalCatalog } private def createDatabase(name: String): Unit = { @@ -87,8 +86,8 @@ class CatalogSuite private def testListColumns(tableName: String, dbName: Option[String]): 
Unit = { val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, dbName)) val columns = dbName - .map { db => sparkSession.catalog.listColumns(db, tableName) } - .getOrElse { sparkSession.catalog.listColumns(tableName) } + .map { db => spark.catalog.listColumns(db, tableName) } + .getOrElse { spark.catalog.listColumns(tableName) } assume(tableMetadata.schema.nonEmpty, "bad test") assume(tableMetadata.partitionColumnNames.nonEmpty, "bad test") assume(tableMetadata.bucketColumnNames.nonEmpty, "bad test") @@ -108,85 +107,85 @@ class CatalogSuite } test("current database") { - assert(sparkSession.catalog.currentDatabase == "default") + assert(spark.catalog.currentDatabase == "default") assert(sessionCatalog.getCurrentDatabase == "default") createDatabase("my_db") - sparkSession.catalog.setCurrentDatabase("my_db") - assert(sparkSession.catalog.currentDatabase == "my_db") + spark.catalog.setCurrentDatabase("my_db") + assert(spark.catalog.currentDatabase == "my_db") assert(sessionCatalog.getCurrentDatabase == "my_db") val e = intercept[AnalysisException] { - sparkSession.catalog.setCurrentDatabase("unknown_db") + spark.catalog.setCurrentDatabase("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list databases") { - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) createDatabase("my_db1") createDatabase("my_db2") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db1", "my_db2")) dropDatabase("my_db1") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db2")) } test("list tables") { - assert(sparkSession.catalog.listTables().collect().isEmpty) + assert(spark.catalog.listTables().collect().isEmpty) createTable("my_table1") createTable("my_table2") createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table1", "my_table2", "my_temp_table")) dropTable("my_table1") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) } test("list tables with database") { - assert(sparkSession.catalog.listTables("default").collect().isEmpty) + assert(spark.catalog.listTables("default").collect().isEmpty) createDatabase("my_db1") createDatabase("my_db2") createTable("my_table1", Some("my_db1")) createTable("my_table2", Some("my_db2")) createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_table1", "my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + 
assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_table1", Some("my_db1")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db1").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2")) val e = intercept[AnalysisException] { - sparkSession.catalog.listTables("unknown_db") + spark.catalog.listTables("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list functions") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions().collect().map(_.name).toSet)) + spark.catalog.listFunctions().collect().map(_.name).toSet)) createFunction("my_func1") createFunction("my_func2") createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) dropFunction("my_func1") dropTempFunction("my_temp_func") - val funcNames2 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(!funcNames2.contains("my_temp_func")) @@ -194,14 +193,14 @@ class CatalogSuite test("list functions with database") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions("default").collect().map(_.name).toSet)) + spark.catalog.listFunctions("default").collect().map(_.name).toSet)) createDatabase("my_db1") createDatabase("my_db2") createFunction("my_func1", Some("my_db1")) createFunction("my_func2", Some("my_db2")) createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2 = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(!funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) @@ -210,14 +209,14 @@ class CatalogSuite assert(funcNames2.contains("my_temp_func")) dropFunction("my_func1", Some("my_db1")) dropTempFunction("my_temp_func") - val funcNames1b = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2b = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1b = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2b = 
spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(!funcNames1b.contains("my_func1")) assert(!funcNames1b.contains("my_temp_func")) assert(funcNames2b.contains("my_func2")) assert(!funcNames2b.contains("my_temp_func")) val e = intercept[AnalysisException] { - sparkSession.catalog.listFunctions("unknown_db") + spark.catalog.listFunctions("unknown_db") } assert(e.getMessage.contains("unknown_db")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index b87f482941..7ead97bbf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -33,61 +33,61 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("programmatic ways of basic setting and getting") { // Set a conf first. - sqlContext.setConf(testKey, testVal) + spark.conf.set(testKey, testVal) // Clear the conf. - sqlContext.conf.clear() + spark.wrapped.conf.clear() // After clear, only overrideConfs used by unit test should be in the SQLConf. - assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + assert(spark.conf.getAll === TestSQLContext.overrideConfs) - sqlContext.setConf(testKey, testVal) - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + spark.conf.set(testKey, testVal) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
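For reference, the catalog and configuration accessors exercised above reduce to roughly this shape (a sketch only, assuming an active SparkSession named spark; the config key below is made up for illustration):

// Sketch: session-level runtime config and catalog access, replacing the
// sqlContext.setConf/getConf and sparkSession.catalog calls in these suites.
spark.conf.set("spark.sql.demo.key", "demoValue")
assert(spark.conf.get("spark.sql.demo.key") == "demoValue")
assert(spark.conf.getAll.contains("spark.sql.demo.key"))
assert(spark.catalog.currentDatabase == "default")
val tableNames = spark.catalog.listTables().collect().map(_.name).toSet
// The mutable SQLConf itself is private[sql]; tests in org.apache.spark.sql
// still reach it through the wrapped SQLContext when they need a full clear:
spark.wrapped.conf.clear()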
- assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("parse SQL set commands") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() sql(s"set $testKey=$testVal") - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(sqlContext.getConf("some.property", "0") === "20") + assert(spark.conf.get("some.property", "0") === "20") sql("set some.property = 40") - assert(sqlContext.getConf("some.property", "0") === "40") + assert(spark.conf.get("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(sqlContext.getConf(key, "0") === vs) + assert(spark.conf.get(key, "0") === vs) sql(s"set $key=") - assert(sqlContext.getConf(key, "0") === "") + assert(spark.conf.get(key, "0") === "") - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("deprecated property") { - sqlContext.conf.clear() - val original = sqlContext.conf.numShufflePartitions + spark.wrapped.conf.clear() + val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) try{ sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(sqlContext.conf.numShufflePartitions === 10) + assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10) } finally { sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") } } test("invalid conf value") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } @@ -95,35 +95,35 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") - assert(sqlContext.conf.targetPostShuffleInputSize === 100) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") - assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1024) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") - assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1048576) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") - assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1073741824) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") - assert(sqlContext.conf.targetPostShuffleInputSize === -1) + 
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === -1) // Test overflow exception intercept[IllegalArgumentException] { // This value exceeds Long.MaxValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") } intercept[IllegalArgumentException] { // This value less than Long.MinValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") } - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("SparkSession can access configs set in SparkConf") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 47a1017caa..44d1b9ddda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -337,39 +337,39 @@ class JDBCSuite extends SparkFunSuite } test("Basic API") { - assert(sqlContext.read.jdbc( + assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(sqlContext.read.jdbc( + assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } test("Partitioning on column that might have null values.") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) .collect().length === 4) assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) .collect().length === 4) // partitioning on a nullable quoted column assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) .collect().length === 4) } @@ -429,9 +429,9 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types") { - val rows = sqlContext.read.jdbc( + val rows = spark.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -439,8 +439,8 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types in cache") { - val rows = 
sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -448,7 +448,7 @@ class JDBCSuite extends SparkFunSuite } test("test types for null value") { - val rows = sqlContext.read.jdbc( + val rows = spark.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -495,7 +495,7 @@ class JDBCSuite extends SparkFunSuite test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -629,7 +629,7 @@ class JDBCSuite extends SparkFunSuite // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); val date = java.sql.Date.valueOf("1995-01-01") - val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(0).getAs[java.sql.Timestamp](2) @@ -639,7 +639,7 @@ class JDBCSuite extends SparkFunSuite test("test credentials in the properties are not in plan output") { val df = sql("SELECT * FROM parts") val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } // test the JdbcRelation toString output @@ -649,9 +649,9 @@ class JDBCSuite extends SparkFunSuite } test("test credentials in the connection url are not in the plan output") { - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e23ee66931..48fa5f9822 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -88,50 +88,50 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df = 
spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) assert( - 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df = 
spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -141,14 +141,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 92061133cd..754aa32a30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -24,7 +24,7 @@ private[sql] abstract class DataSourceTest extends QueryTest { // We want to test some edge cases. protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(sqlContext.sparkContext) + val ctx = new SQLContext(spark.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a9b1970a7c..a2decadbe0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = sqlContext.range(100).select($"id", lit(1).as("data")) + val df = spark.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = sqlContext.range(100) + val base = spark.range(100) val df = base.union(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -58,7 +58,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => val path = f.getAbsolutePath Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) - 
assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 3d69c8a187..a743cdde40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -41,13 +41,13 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with override val streamingTimeout = 20.seconds before { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } after { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } testQuietly("listing") { @@ -57,26 +57,26 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with withQueriesOn(ds1, ds2, ds3) { queries => require(queries.size === 3) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) val (q1, q2, q3) = (queries(0), queries(1), queries(2)) - assert(sqlContext.streams.get(q1.name).eq(q1)) - assert(sqlContext.streams.get(q2.name).eq(q2)) - assert(sqlContext.streams.get(q3.name).eq(q3)) + assert(spark.streams.get(q1.name).eq(q1)) + assert(spark.streams.get(q2.name).eq(q2)) + assert(spark.streams.get(q3.name).eq(q3)) intercept[IllegalArgumentException] { - sqlContext.streams.get("non-existent-name") + spark.streams.get("non-existent-name") } q1.stop() - assert(sqlContext.streams.active.toSet === Set(q2, q3)) + assert(spark.streams.active.toSet === Set(q2, q3)) val ex1 = withClue("no error while getting non-active query") { intercept[IllegalArgumentException] { - sqlContext.streams.get(q1.name) + spark.streams.get(q1.name) } } assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") - assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(spark.streams.get(q2.name).eq(q2)) m2.addData(0) // q2 should terminate with error @@ -86,11 +86,11 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with } withClue("no error while getting non-active query") { intercept[IllegalArgumentException] { - sqlContext.streams.get(q2.name).eq(q2) + spark.streams.get(q2.name).eq(q2) } } - assert(sqlContext.streams.active.toSet === Set(q3)) + assert(spark.streams.active.toSet === Set(q3)) } } @@ -98,7 +98,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(5)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking testAwaitAnyTermination(ExpectBlocked) @@ -112,7 +112,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectNotBlocked) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate a query asynchronously with exception and see awaitAnyTermination 
throws @@ -125,7 +125,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectException[SparkException]) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws @@ -144,7 +144,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(6)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking or non-blocking depending on timeout values testAwaitAnyTermination( @@ -173,7 +173,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with ExpectNotBlocked, awaitTimeout = 4 seconds, expectedReturnedValue = true) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination( ExpectBlocked, awaitTimeout = 4 seconds, @@ -196,7 +196,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testBehaviorFor = 4 seconds) // Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q3 = stopRandomQueryAsync(2 seconds, withError = true) testAwaitAnyTermination( ExpectNotBlocked, @@ -214,7 +214,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws // the exception - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q4 = stopRandomQueryAsync(10 milliseconds, withError = false) testAwaitAnyTermination( @@ -238,7 +238,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val df = ds.toDF val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - query = sqlContext + query = spark .streams .startQuery( StreamExecution.nextName, @@ -272,10 +272,10 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with def awaitTermFunc(): Unit = { if (awaitTimeout != null && awaitTimeout.toMillis > 0) { - val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis) + val returnedValue = spark.streams.awaitAnyTermination(awaitTimeout.toMillis) assert(returnedValue === expectedReturnedValue, "Returned value does not match expected") } else { - sqlContext.streams.awaitAnyTermination() + spark.streams.awaitAnyTermination() } } @@ -287,7 +287,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with import scala.concurrent.ExecutionContext.Implicits.global - val activeQueries = sqlContext.streams.active + val activeQueries = spark.streams.active val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) Future { Thread.sleep(stopAfter.toMillis) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index c7b2b99822..cb53b2b1aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -54,18 +54,18 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = { LastOptions.parameters = parameters LastOptions.schema = schema - LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters) ("dummySource", fakeSchema) } override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -73,14 +73,14 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { LastOptions.parameters = parameters LastOptions.schema = schema LastOptions.mockStreamSourceProvider.createSource( - sqlContext, metadataPath, schema, providerName, parameters) + spark, metadataPath, schema, providerName, parameters) new Source { override def schema: StructType = fakeSchema override def getOffset: Option[Offset] = Some(new LongOffset(0)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - import sqlContext.implicits._ + import spark.implicits._ Seq[Int]().toDS().toDF() } @@ -88,12 +88,12 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } override def createSink( - sqlContext: SQLContext, + spark: SQLContext, parameters: Map[String, String], partitionColumns: Seq[String]): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } @@ -107,11 +107,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath after { - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("resolve default source") { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream() .write @@ -122,7 +122,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("resolve full class") { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test.DefaultSource") .stream() .write @@ -136,7 +136,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) @@ -164,7 +164,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("partitioning") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() @@ -204,7 +204,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("stream paths") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .stream("/test") @@ -223,7 
+223,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("test different data types for options") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) @@ -253,7 +253,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Start a query with a specific name */ def startQueryWithName(name: String = ""): ContinuousQuery = { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") .write @@ -265,7 +265,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Start a query without specifying a name */ def startQueryWithoutName(): ContinuousQuery = { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") .write @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Get the names of active streams */ def activeStreamNames: Set[String] = { - val streams = sqlContext.streams.active + val streams = spark.streams.active val names = streams.map(_.name).toSet assert(streams.length === names.size, s"names of active queries are not unique: $names") names @@ -307,11 +307,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q1.stop() val q5 = startQueryWithName("name") assert(activeStreamNames.contains("name")) - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("trigger") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") @@ -339,11 +339,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val checkpointLocation = newMetadataDir - val df1 = sqlContext.read + val df1 = spark.read .format("org.apache.spark.sql.streaming.test") .stream() - val df2 = sqlContext.read + val df2 = spark.read .format("org.apache.spark.sql.streaming.test") .stream() @@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.wrapped, checkpointLocation + "/sources/0", None, "org.apache.spark.sql.streaming.test", Map.empty) verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.wrapped, checkpointLocation + "/sources/1", None, "org.apache.spark.sql.streaming.test", @@ -372,35 +372,35 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath test("check trigger() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) assert(e.getMessage == "trigger() can only be called on continuous queries;") } test("check queryName() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.queryName("queryName")) assert(e.getMessage == "queryName() can only be called on continuous queries;") } test("check startStream() can only be called on continuous queries") { - val df = 
sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.startStream()) assert(e.getMessage == "startStream() can only be called on continuous queries;") } test("check startStream(path) can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.startStream("non_exist_path")) assert(e.getMessage == "startStream() can only be called on continuous queries;") } test("check mode(SaveMode) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -409,7 +409,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check mode(string) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -418,7 +418,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check bucketBy() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -427,7 +427,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check sortBy() can only be called on non-continuous queries;") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -436,7 +436,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check save(path) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -445,7 +445,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check save() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -454,7 +454,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check insertInto() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -463,7 +463,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check saveAsTable() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -472,7 +472,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check jdbc() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -481,7 +481,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check json() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -490,7 +490,7 @@ class 
DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check parquet() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -499,7 +499,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check orc() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -508,7 +508,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check text() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -517,7 +517,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check csv() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index e937fc3e87..6238b74ffa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -40,11 +40,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration val fileFormat = new parquet.DefaultSource() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = sqlContext + val df = spark .range(start, end, 1, numPartitions) .select($"id", lit(100).as("data")) val writer = new FileStreamSinkWriter( @@ -56,7 +56,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val files1 = writeRange(0, 10, 2) assert(files1.size === 2, s"unexpected number of files: $files1") checkFilesExist(path, files1, "file not written") - checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) // Append and check whether new files are written correctly and old files still exist val files2 = writeRange(10, 20, 3) @@ -64,7 +64,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { assert(files2.intersect(files1).isEmpty, "old files returned") checkFilesExist(path, files2, s"New file not written") checkFilesExist(path, files1, s"Old file not found") - checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) } test("FileStreamSinkWriter - partitioned data") { @@ -72,11 +72,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration val fileFormat = new parquet.DefaultSource() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = sqlContext + val df = spark .range(start, end, 1, numPartitions) .flatMap(x => 
Iterator(x, x, x)).toDF("id") .select($"id", lit(100).as("data1"), lit(1000).as("data2")) @@ -103,7 +103,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { checkOneFileWrittenPerKey(0 until 10, files1) val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1) // Append and check whether new files are written correctly and old files still exist val files2 = writeRange(0, 20, 3) @@ -114,7 +114,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { checkOneFileWrittenPerKey(0 until 20, files2) val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1 ++ answer2) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2) } test("FileStreamSink - unpartitioned writing and batch reading") { @@ -139,7 +139,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { query.processAllAvailable() } - val outputDf = sqlContext.read.parquet(outputDir).as[Int] + val outputDf = spark.read.parquet(outputDir).as[Int] checkDataset(outputDf, 1, 2, 3) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a62852b512..4b95d65627 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -103,9 +103,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { val reader = if (schema.isDefined) { - sqlContext.read.format(format).schema(schema.get) + spark.read.format(format).schema(schema.get) } else { - sqlContext.read.format(format) + spark.read.format(format) } reader.stream(path) } @@ -149,7 +149,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { format: Option[String], path: Option[String], schema: Option[StructType] = None): StructType = { - val reader = sqlContext.read + val reader = spark.read format.foreach(reader.format) schema.foreach(reader.schema) val df = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 50703e532f..4efb7cf52d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -100,7 +100,7 @@ class FileStressSuite extends StreamTest with SharedSQLContext { } writer.start() - val input = sqlContext.read.format("text").stream(inputDir) + val input = spark.read.format("text").stream(inputDir) def startStream(): ContinuousQuery = { val output = input @@ -150,6 +150,6 @@ class FileStressSuite extends StreamTest with SharedSQLContext { streamThread.join() logError(s"Stream restarted $failures times.") - assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords) + assert(spark.read.parquet(outputDir).distinct().count() == numRecords) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 74ca3977d6..09c35bbf2c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -44,13 +44,13 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3) input.addData(4, 5, 6) query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3, 4, 5, 6) query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index bcd3cba55a..6a8b280174 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -94,7 +94,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { .startStream(outputDir.getAbsolutePath) try { query.processAllAvailable() - val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long] + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] checkDataset[Long](outputDf, (0L to 10L).toArray: _*) } finally { query.stop() @@ -103,7 +103,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { } } - val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream() + val df = spark.read.format(classOf[FakeDefaultSource].getName).stream() assertDF(df) assertDF(df) } @@ -162,13 +162,13 @@ class FakeDefaultSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -190,7 +190,7 @@ class FakeDefaultSource extends StreamSourceProvider { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 - sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 7fa6760b71..03369c5a48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql.test import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} /** * A collection of sample data used in SQL tests. 
*/ private[sql] trait SQLTestData { self => - protected def sqlContext: SQLContext + protected def spark: SparkSession // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.wrapped } import internalImplicits._ @@ -39,21 +39,21 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() df.registerTempTable("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } protected lazy val testData2: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -65,7 +65,7 @@ private[sql] trait SQLTestData { self => } protected lazy val testData3: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -73,14 +73,14 @@ private[sql] trait SQLTestData { self => } protected lazy val negativeData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -92,7 +92,7 @@ private[sql] trait SQLTestData { self => } protected lazy val decimalData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -104,7 +104,7 @@ private[sql] trait SQLTestData { self => } protected lazy val binaryData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: @@ -115,7 +115,7 @@ private[sql] trait SQLTestData { self => } protected lazy val upperCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -127,7 +127,7 @@ private[sql] trait SQLTestData { self => } protected lazy val lowerCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -137,7 +137,7 @@ private[sql] trait SQLTestData { self => } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) 
rdd.toDF().registerTempTable("arrayData") @@ -145,7 +145,7 @@ private[sql] trait SQLTestData { self => } protected lazy val mapData: RDD[MapData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -156,13 +156,13 @@ private[sql] trait SQLTestData { self => } protected lazy val repeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -170,7 +170,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -180,7 +180,7 @@ private[sql] trait SQLTestData { self => } protected lazy val allNulls: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -190,7 +190,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullStrings: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -199,13 +199,13 @@ private[sql] trait SQLTestData { self => } protected lazy val tableName: DataFrame = { - val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -214,13 +214,13 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -228,7 +228,7 @@ private[sql] trait SQLTestData { self => } protected lazy val salary: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -236,7 +236,7 @@ private[sql] trait SQLTestData { self => } protected lazy val complexData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() @@ -245,7 +245,7 @@ private[sql] trait 
SQLTestData { self => } protected lazy val courseSales: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( CourseSales("dotNET", 2012, 10000) :: CourseSales("Java", 2012, 20000) :: CourseSales("dotNET", 2012, 5000) :: @@ -259,7 +259,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. */ def loadTestData(): Unit = { - assert(sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(spark != null, "attempted to initialize test data before SparkSession.") emptyTestData testData testData2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6d2b95e83a..a49a8c9f2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -50,23 +50,23 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def sparkContext = sqlContext.sparkContext + protected def sparkContext = spark.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = sqlContext.sql _ + protected lazy val sql = spark.sql _ /** * A helper object for importing SQL implicits. * - * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * Note that the alternative of importing `spark.implicits._` is not possible here. * This is because we create the [[SQLContext]] immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.wrapped } /** @@ -92,12 +92,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -138,9 +138,9 @@ private[sql] trait SQLTestUtils // temp tables that never got created. functions.foreach { case (functionName, isTemporary) => val withTemporary = if (isTemporary) "TEMPORARY" else "" - sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") assert( - !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), s"Function $functionName should have been dropped. But, it still exists.") } } @@ -153,7 +153,7 @@ private[sql] trait SQLTestUtils try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove // temp tables that never got created. 
- try tableNames.foreach(sqlContext.dropTempTable) catch { + try tableNames.foreach(spark.catalog.dropTempTable) catch { case _: NoSuchTableException => } } @@ -165,7 +165,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + spark.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -176,7 +176,7 @@ private[sql] trait SQLTestUtils protected def withView(viewNames: String*)(f: => Unit): Unit = { try f finally { viewNames.foreach { name => - sqlContext.sql(s"DROP VIEW IF EXISTS $name") + spark.sql(s"DROP VIEW IF EXISTS $name") } } } @@ -191,12 +191,12 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + spark.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally spark.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -204,8 +204,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sessionState.catalog.setCurrentDatabase(db) - try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default") + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") } /** @@ -221,7 +221,7 @@ private[sql] trait SQLTestUtils .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) - sqlContext.createDataFrame(childRDD, schema) + spark.createDataFrame(childRDD, schema) } /** @@ -229,7 +229,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, plan) + Dataset.ofRows(spark, plan) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 914c6a5509..620bfa995a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,37 +17,42 @@ package org.apache.spark.sql.test -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.SparkConf +import org.apache.spark.sql.{SparkSession, SQLContext} /** - * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ trait SharedSQLContext extends SQLTestUtils { protected val sparkConf = new SparkConf() /** - * The [[TestSQLContext]] to use for all tests in this suite. + * The [[TestSparkSession]] to use for all tests in this suite. * * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local * mode with the default test configurations. */ - private var _ctx: TestSQLContext = null + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark /** * The [[TestSQLContext]] to use for all tests in this suite. 
*/ - protected implicit def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _spark.wrapped /** - * Initialize the [[TestSQLContext]]. + * Initialize the [[TestSparkSession]]. */ protected override def beforeAll(): Unit = { SQLContext.clearSqlListener() - if (_ctx == null) { - _ctx = new TestSQLContext(sparkConf) + if (_spark == null) { + _spark = new TestSparkSession(sparkConf) } // Ensure we have initialized the context before calling parent code super.beforeAll() @@ -58,9 +63,9 @@ trait SharedSQLContext extends SQLTestUtils { */ protected override def afterAll(): Unit = { try { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null + if (_spark != null) { + _spark.stop() + _spark = null } } finally { super.afterAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 5ef80b9aa3..785e3452a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,44 +18,32 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SQLConf} /** - * A special [[SQLContext]] prepared for testing. + * A special [[SparkSession]] prepared for testing. */ -private[sql] class TestSQLContext( - @transient override val sparkSession: SparkSession, - isRootContext: Boolean) - extends SQLContext(sparkSession, isRootContext) { self => - - def this(sc: SparkContext) { - this(new TestSparkSession(sc), true) - } - +private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", sparkConf.set("spark.sql.testkey", "true"))) } def this() { - this(new SparkConf) - } - - // Needed for Java tests - def loadTestData(): Unit = { - testData.loadTestData() - } - - private object testData extends SQLTestData { - protected override def sqlContext: SQLContext = self + this { + val conf = new SparkConf() + conf.set("spark.sql.testkey", "true") + + val spark = SparkSession.builder + .master("local[2]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() + spark.sparkContext + } } -} - - -private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => - @transient protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { @@ -70,6 +58,14 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } } + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def spark: SparkSession = self + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala index 54acd4db3c..8788898fc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -36,11 +36,11 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with import testImplicits._ after { - sqlContext.streams.active.foreach(_.stop()) - 
assert(sqlContext.streams.active.isEmpty) + spark.streams.active.foreach(_.stop()) + assert(spark.streams.active.isEmpty) assert(addedListeners.isEmpty) // Make sure we don't leak any events to the next test - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) } test("single listener") { @@ -112,17 +112,17 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with val listener1 = new QueryStatusCollector val listener2 = new QueryStatusCollector - sqlContext.streams.addListener(listener1) + spark.streams.addListener(listener1) assert(isListenerActive(listener1) === true) assert(isListenerActive(listener2) === false) - sqlContext.streams.addListener(listener2) + spark.streams.addListener(listener2) assert(isListenerActive(listener1) === true) assert(isListenerActive(listener2) === true) - sqlContext.streams.removeListener(listener1) + spark.streams.removeListener(listener1) assert(isListenerActive(listener1) === false) assert(isListenerActive(listener2) === true) } finally { - addedListeners.foreach(sqlContext.streams.removeListener) + addedListeners.foreach(spark.streams.removeListener) } } @@ -146,18 +146,18 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { try { failAfter(1 minute) { - sqlContext.streams.addListener(listener) + spark.streams.addListener(listener) body } } finally { - sqlContext.streams.removeListener(listener) + spark.streams.removeListener(listener) } } private def addedListeners(): Array[ContinuousQueryListener] = { val listenerBusMethod = PrivateMethod[ContinuousQueryListenerBus]('listenerBus) - val listenerBus = sqlContext.streams invokePrivate listenerBusMethod() + val listenerBus = spark.streams invokePrivate listenerBusMethod() listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 8a0578c1ff..3ae5ce610d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -39,7 +39,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += ((funcName, qe, duration)) } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j") df.select("i").collect() @@ -55,7 +55,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) assert(metrics(1)._3 > 0) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("execute callback functions when a DataFrame action failed") { @@ -68,7 +68,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { // Only test failed case here, so no need to implement `onSuccess` override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") @@ -82,7 +82,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { 
assert(metrics(0)._2.analyzed.isInstanceOf[Project]) assert(metrics(0)._3.getMessage == e.getMessage) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("get numRows metrics by callback") { @@ -99,7 +99,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += metric.value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() df.collect() @@ -111,7 +111,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1) === 1) assert(metrics(2) === 2) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never @@ -131,10 +131,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += bottomAgg.longMetric("dataSize").value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val sparkListener = new SaveInfoListener - sqlContext.sparkContext.addSparkListener(sparkListener) + spark.sparkContext.addSparkListener(sparkListener) val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j") df.groupBy("i").count().collect() @@ -157,6 +157,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0) == topAggDataSize) assert(metrics(1) == bottomAggDataSize) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } } |
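The streaming-suite changes above all exercise the same round trip: build a continuous DataFrame with `spark.read...stream()`, start it with `startStream(...)`, and read the sink back in batch mode. A minimal sketch of that flow, assuming `spark` is an active test session and `inputDir`, `checkpointDir`, and `outputDir` are hypothetical local directory paths:

    val input = spark.read.format("text").stream(inputDir)      // continuous (streaming) DataFrame
    val query = input.write
      .format("parquet")
      .option("checkpointLocation", checkpointDir)
      .startStream(outputDir)                                    // returns a ContinuousQuery
    query.processAllAvailable()                                  // block until all available input is processed
    val outputDf = spark.read.parquet(outputDir)                 // batch read of what the sink wrote
    query.stop()

Calling a batch-only writer method (save, jdbc, insertInto, and so on) on such a streaming DataFrame raises an AnalysisException, which is what the guard tests above assert via intercept[AnalysisException].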
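SQLTestData now builds every fixture from `spark.sparkContext` and registers it as a temp table. A self-contained sketch of the same pattern (the table name and data are illustrative, not one of the fixtures above):

    import spark.implicits._

    val testData = spark.sparkContext
      .parallelize((1 to 100).map(i => (i, i.toString)))
      .toDF("key", "value")
    testData.registerTempTable("testData")

    assert(spark.sql("SELECT count(*) FROM testData").head().getLong(0) == 100)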
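The reworked `withSQLConf` helper goes through the public `spark.conf` API instead of the internal SQLConf. The important detail is that previous values are captured before the block runs and restored (or unset) afterwards, so a failing test cannot leak configuration into the next one. Usage looks roughly like this (the config key is just an example):

    withSQLConf("spark.sql.shuffle.partitions" -> "4") {
      // inside the block, spark.conf.get("spark.sql.shuffle.partitions") == "4"
      assert(spark.sql("SELECT 1").collect().length == 1)
    }
    // outside the block the previous value is restored, or the key is unset if it was never set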
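SharedSQLContext now owns a TestSparkSession rather than a TestSQLContext, exposing the old `sqlContext` only as `_spark.wrapped` for code that still needs it. Outside the test harness, the equivalent setup and teardown is just the builder API; a sketch using the same master, app name, and test key as the suite above:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    val conf = new SparkConf().set("spark.sql.testkey", "true")
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("test-sql-context")
      .config(conf)
      .getOrCreate()
    try {
      // run tests against `spark` (spark.sql, spark.read, spark.createDataFrame, ...)
    } finally {
      spark.stop()            // replaces sqlContext.sparkContext.stop()
    }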
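DataFrameCallbackSuite registers its listeners through `spark.listenerManager` instead of `sqlContext.listenerManager`. A minimal listener mirroring the suite's callbacks; the onFailure signature is assumed from the failed-action test above, and the println bodies are placeholders:

    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener
    import spark.implicits._

    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
        println(s"$funcName succeeded (duration=$duration)")
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"$funcName failed: ${exception.getMessage}")
    }

    spark.listenerManager.register(listener)
    try {
      // run a query; the listener is invoked with the name of the action that triggered it
      Seq(1 -> "a").toDF("i", "j").select("i").collect()
    } finally {
      spark.listenerManager.unregister(listener)
    }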