author    Sandeep Singh <sandeep@techaddict.me>    2016-05-10 11:17:47 -0700
committer Andrew Or <andrew@databricks.com>    2016-05-10 11:17:47 -0700
commit    ed0b4070fb50054b1ecf66ff6c32458a4967dfd3 (patch)
tree      68b3ad1a3ca22f2e0b5966db517c9bc42da3d254 /sql
parent    bcfee153b1cacfe617e602f3b72c0877e0bdf1f7 (diff)
download  spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.tar.gz
          spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.tar.bz2
          spark-ed0b4070fb50054b1ecf66ff6c32458a4967dfd3.zip
[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request?

Use SparkSession instead of SQLContext in the Scala/Java TestSuites. Since this PR is already very large, the Python TestSuites are being updated in a separate PR.

## How was this patch tested?

Existing tests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12907 from techaddict/SPARK-15037.
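The change repeated throughout the diff below is the same setup/tear-down migration: each suite builds a SparkSession (or TestSparkSession) instead of constructing a SQLContext on top of a SparkContext. The following Scala sketch is not taken from the patch itself (the object name and the toy assertion are hypothetical); it only illustrates the pattern for a plain local-mode session:

```scala
// A minimal, self-contained sketch of the setup/tear-down migration this PR
// applies across the suites. The object name and the toy assertion are
// illustrative only and do not appear in the patch.
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession

object SparkSessionSetupSketch {
  def main(args: Array[String]): Unit = {
    // Old pattern (what the suites did before this change):
    //   val sc = new SparkContext("local[*]", "testing")
    //   val sqlContext = new SQLContext(sc)
    // New pattern: build a SparkSession directly.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("testing")
      .getOrCreate()

    // Suites that still need a JavaSparkContext derive it from the session.
    val jsc = new JavaSparkContext(spark.sparkContext)

    // DataFrame/SQL entry points that previously went through sqlContext
    // now go through the session.
    val df = spark.range(0, 10).toDF("id")
    assert(df.count() == 10L)

    // Tear-down collapses to a single stop() on the session.
    spark.stop()
  }
}
```

In the Java suites the same builder chain appears verbatim, and the JavaSparkContext is likewise obtained from spark.sparkContext().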
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java | 42
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 70
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 89
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java | 27
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java | 20
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java | 33
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 216
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 12
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 12
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 82
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala | 38
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala | 13
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 26
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 22
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala | 20
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala | 68
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 72
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 52
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala | 45
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 36
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala | 19
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 27
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 20
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala | 39
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala | 19
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala | 39
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 80
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala | 48
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 124
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala | 46
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala | 14
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala | 34
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 49
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala | 34
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala | 102
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala | 76
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala | 14
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala | 21
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala | 20
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 30
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala | 33
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala | 58
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 36
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala | 69
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala | 74
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 34
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 46
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala | 50
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala | 82
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala | 10
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala | 56
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 38
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala | 29
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 48
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala | 20
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala | 18
-rw-r--r-- sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala | 4
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala | 4
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala | 12
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala | 8
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 16
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 6
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 18
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala | 88
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala | 12
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 110
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala | 18
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 32
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala | 14
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala | 8
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 38
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala | 4
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala | 6
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala | 66
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala | 20
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala | 4
110 files changed, 1651 insertions, 1574 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 189cc3972c..f2ae40e644 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -28,14 +28,13 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType;
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
public class JavaApplySchemaSuite implements Serializable {
- private transient JavaSparkContext javaCtx;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
@Before
public void setUp() {
- SparkContext context = new SparkContext("local[*]", "testing");
- javaCtx = new JavaSparkContext(context);
- sqlContext = new SQLContext(context);
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
}
@After
public void tearDown() {
- sqlContext.sparkContext().stop();
- sqlContext = null;
- javaCtx = null;
+ spark.stop();
+ spark = null;
}
public static class Person implements Serializable {
@@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable {
person2.setAge(28);
personList.add(person2);
- JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+ JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
new Function<Person, Row>() {
@Override
public Row call(Person person) throws Exception {
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList();
+ List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();
List<Row> expected = new ArrayList<>(2);
expected.add(RowFactory.create("Michael", 29));
@@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable {
person2.setAge(28);
personList.add(person2);
- JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+ JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
new Function<Person, Row>() {
@Override
public Row call(Person person) {
@@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD()
+ List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
.map(new Function<Row, String>() {
@Override
public String call(Row row) {
@@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable {
@Test
public void applySchemaToJSON() {
- JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList(
+ JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList(
"{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
"\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
"\"boolean\":true, \"null\":null}",
@@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
+ Dataset<Row> df1 = spark.read().json(jsonRDD);
StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
df1.registerTempTable("jsonTable1");
- List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
+ List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+ Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonRDD);
StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
df2.registerTempTable("jsonTable2");
- List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList();
+ List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
Assert.assertEquals(expectedResult, actual2);
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 1eb680dc4c..324ebbae38 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -20,12 +20,7 @@ package test.org.apache.spark.sql;
import java.io.Serializable;
import java.net.URISyntaxException;
import java.net.URL;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-import java.util.ArrayList;
+import java.util.*;
import scala.collection.JavaConverters;
import scala.collection.Seq;
@@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.junit.*;
-import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.*;
+import org.apache.spark.util.sketch.BloomFilter;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;
-import org.apache.spark.util.sketch.BloomFilter;
public class JavaDataFrameSuite {
+ private transient TestSparkSession spark;
private transient JavaSparkContext jsc;
- private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
- SparkContext sc = new SparkContext("local[*]", "testing");
- jsc = new JavaSparkContext(sc);
- context = new TestSQLContext(sc);
- context.loadTestData();
+ spark = new TestSparkSession();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ spark.loadTestData();
}
@After
public void tearDown() {
- context.sparkContext().stop();
- context = null;
- jsc = null;
+ spark.stop();
+ spark = null;
}
@Test
public void testExecution() {
- Dataset<Row> df = context.table("testData").filter("key = 1");
+ Dataset<Row> df = spark.table("testData").filter("key = 1");
Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
}
@Test
public void testCollectAndTake() {
- Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
Assert.assertEquals(3, df.select("key").collectAsList().size());
Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
@@ -83,7 +77,7 @@ public class JavaDataFrameSuite {
*/
@Test
public void testVarargMethods() {
- Dataset<Row> df = context.table("testData");
+ Dataset<Row> df = spark.table("testData");
df.toDF("key1", "value1");
@@ -112,7 +106,7 @@ public class JavaDataFrameSuite {
df.select(coalesce(col("key")));
// Varargs with mathfunctions
- Dataset<Row> df2 = context.table("testData2");
+ Dataset<Row> df2 = spark.table("testData2");
df2.select(exp("a"), exp("b"));
df2.select(exp(log("a")));
df2.select(pow("a", "a"), pow("b", 2.0));
@@ -126,7 +120,7 @@ public class JavaDataFrameSuite {
@Ignore
public void testShow() {
// This test case is intended ignored, but to make sure it compiles correctly
- Dataset<Row> df = context.table("testData");
+ Dataset<Row> df = spark.table("testData");
df.show();
df.show(1000);
}
@@ -194,7 +188,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromLocalJavaBeans() {
Bean bean = new Bean();
List<Bean> data = Arrays.asList(bean);
- Dataset<Row> df = context.createDataFrame(data, Bean.class);
+ Dataset<Row> df = spark.createDataFrame(data, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -202,7 +196,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromJavaBeans() {
Bean bean = new Bean();
JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
- Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
+ Dataset<Row> df = spark.createDataFrame(rdd, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -210,7 +204,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFromFromList() {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));
- Dataset<Row> df = context.createDataFrame(rows, schema);
+ Dataset<Row> df = spark.createDataFrame(rows, schema);
List<Row> result = df.collectAsList();
Assert.assertEquals(1, result.size());
}
@@ -239,7 +233,7 @@ public class JavaDataFrameSuite {
@Test
public void testCrosstab() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
Dataset<Row> crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
@@ -258,7 +252,7 @@ public class JavaDataFrameSuite {
@Test
public void testFrequentItems() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
String[] cols = {"a"};
Dataset<Row> results = df.stat().freqItems(cols, 0.2);
Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
@@ -266,21 +260,21 @@ public class JavaDataFrameSuite {
@Test
public void testCorrelation() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
Double pearsonCorr = df.stat().corr("a", "b", "pearson");
Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
}
@Test
public void testCovariance() {
- Dataset<Row> df = context.table("testData2");
+ Dataset<Row> df = spark.table("testData2");
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1.0e-6);
}
@Test
public void testSampleBy() {
- Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+ Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
Assert.assertEquals(0, actual.get(0).getLong(0));
@@ -291,7 +285,7 @@ public class JavaDataFrameSuite {
@Test
public void pivot() {
- Dataset<Row> df = context.table("courseSales");
+ Dataset<Row> df = spark.table("courseSales");
List<Row> actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
.agg(sum("earnings")).orderBy("year").collectAsList();
@@ -324,10 +318,10 @@ public class JavaDataFrameSuite {
@Test
public void testGenericLoad() {
- Dataset<Row> df1 = context.read().format("text").load(getResource("text-suite.txt"));
+ Dataset<Row> df1 = spark.read().format("text").load(getResource("text-suite.txt"));
Assert.assertEquals(4L, df1.count());
- Dataset<Row> df2 = context.read().format("text").load(
+ Dataset<Row> df2 = spark.read().format("text").load(
getResource("text-suite.txt"),
getResource("text-suite2.txt"));
Assert.assertEquals(5L, df2.count());
@@ -335,10 +329,10 @@ public class JavaDataFrameSuite {
@Test
public void testTextLoad() {
- Dataset<String> ds1 = context.read().text(getResource("text-suite.txt"));
+ Dataset<String> ds1 = spark.read().text(getResource("text-suite.txt"));
Assert.assertEquals(4L, ds1.count());
- Dataset<String> ds2 = context.read().text(
+ Dataset<String> ds2 = spark.read().text(
getResource("text-suite.txt"),
getResource("text-suite2.txt"));
Assert.assertEquals(5L, ds2.count());
@@ -346,7 +340,7 @@ public class JavaDataFrameSuite {
@Test
public void testCountMinSketch() {
- Dataset<Long> df = context.range(1000);
+ Dataset<Long> df = spark.range(1000);
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -371,7 +365,7 @@ public class JavaDataFrameSuite {
@Test
public void testBloomFilter() {
- Dataset<Long> df = context.range(1000);
+ Dataset<Long> df = spark.range(1000);
BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f1b1c22e4a..8354a5bdac 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,46 +23,43 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
-import com.google.common.base.Objects;
-import org.junit.rules.ExpectedException;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple5;
+import com.google.common.base.Objects;
import org.junit.*;
+import org.junit.rules.ExpectedException;
import org.apache.spark.Accumulator;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.*;
import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.StructType;
-
-import static org.apache.spark.sql.functions.*;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaDatasetSuite implements Serializable {
+ private transient TestSparkSession spark;
private transient JavaSparkContext jsc;
- private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
- SparkContext sc = new SparkContext("local[*]", "testing");
- jsc = new JavaSparkContext(sc);
- context = new TestSQLContext(sc);
- context.loadTestData();
+ spark = new TestSparkSession();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ spark.loadTestData();
}
@After
public void tearDown() {
- context.sparkContext().stop();
- context = null;
- jsc = null;
+ spark.stop();
+ spark = null;
}
private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
List<String> collected = ds.collectAsList();
Assert.assertEquals(Arrays.asList("hello", "world"), collected);
}
@@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testTake() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
List<String> collected = ds.takeAsList(1);
Assert.assertEquals(Arrays.asList("hello"), collected);
}
@@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testToLocalIterator() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Iterator<String> iter = ds.toLocalIterator();
Assert.assertEquals("hello", iter.next());
Assert.assertEquals("world", iter.next());
@@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Assert.assertEquals("hello", ds.first());
Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
public void testForeach() {
final Accumulator<Integer> accum = jsc.accumulator(0);
List<String> data = Arrays.asList("a", "b", "c");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
ds.foreach(new ForeachFunction<String>() {
@Override
@@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testReduce() {
List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
+ Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
@@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(
new MapFunction<String, Integer>() {
@Override
@@ -227,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
toSet(reduced.collectAsList()));
List<Integer> data2 = Arrays.asList(2, 6, 10);
- Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
+ Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
new MapFunction<Integer, Integer>() {
@Override
@@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testSelect() {
List<Integer> data = Arrays.asList(2, 6);
- Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
+ Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());
Dataset<Tuple2<Integer, String>> selected = ds.select(
expr("value + 1"),
@@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testSetOperation() {
List<String> data = Arrays.asList("abc", "abc", "xyz");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList()));
List<String> data2 = Arrays.asList("xyz", "foo", "foo");
- Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
+ Dataset<String> ds2 = spark.createDataset(data2, Encoders.STRING());
Dataset<String> intersected = ds.intersect(ds2);
Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testJoin() {
List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
+ Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a");
List<Integer> data2 = Arrays.asList(2, 3, 4);
- Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
+ Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b");
Dataset<Tuple2<Integer, Integer>> joined =
ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable {
public void testTupleEncoder() {
Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
- Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
+ Dataset<Tuple2<Integer, String>> ds2 = spark.createDataset(data2, encoder2);
Assert.assertEquals(data2, ds2.collectAsList());
Encoder<Tuple3<Integer, Long, String>> encoder3 =
Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
List<Tuple3<Integer, Long, String>> data3 =
Arrays.asList(new Tuple3<>(1, 2L, "a"));
- Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
+ Dataset<Tuple3<Integer, Long, String>> ds3 = spark.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
List<Tuple4<Integer, String, Long, String>> data4 =
Arrays.asList(new Tuple4<>(1, "b", 2L, "a"));
- Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
+ Dataset<Tuple4<Integer, String, Long, String>> ds4 = spark.createDataset(data4, encoder4);
Assert.assertEquals(data4, ds4.collectAsList());
Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
@@ -345,7 +342,7 @@ public class JavaDatasetSuite implements Serializable {
List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true));
Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
- context.createDataset(data5, encoder5);
+ spark.createDataset(data5, encoder5);
Assert.assertEquals(data5, ds5.collectAsList());
}
@@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
List<Tuple2<Tuple2<Integer, String>, String>> data =
Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
- Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
+ Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
// test (int, (string, string, long))
@@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable {
List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L)));
Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
- context.createDataset(data2, encoder2);
+ spark.createDataset(data2, encoder2);
Assert.assertEquals(data2, ds2.collectAsList());
// test (int, ((string, long), string))
@@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable {
List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
- context.createDataset(data3, encoder3);
+ spark.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
}
@@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable {
1.7976931348623157E308, new BigDecimal("0.922337203685477589"),
Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE));
Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds =
- context.createDataset(data, encoder);
+ spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable {
Encoder<KryoSerializable> encoder = Encoders.kryo(KryoSerializable.class);
List<KryoSerializable> data = Arrays.asList(
new KryoSerializable("hello"), new KryoSerializable("world"));
- Dataset<KryoSerializable> ds = context.createDataset(data, encoder);
+ Dataset<KryoSerializable> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@@ -450,14 +447,14 @@ public class JavaDatasetSuite implements Serializable {
Encoder<JavaSerializable> encoder = Encoders.javaSerialization(JavaSerializable.class);
List<JavaSerializable> data = Arrays.asList(
new JavaSerializable("hello"), new JavaSerializable("world"));
- Dataset<JavaSerializable> ds = context.createDataset(data, encoder);
+ Dataset<JavaSerializable> ds = spark.createDataset(data, encoder);
Assert.assertEquals(data, ds.collectAsList());
}
@Test
public void testRandomSplit() {
List<String> data = Arrays.asList("hello", "world", "from", "spark");
- Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
double[] arraySplit = {1, 2, 3};
List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1);
@@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable {
obj2.setF(Arrays.asList(300L, null, 400L));
List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
- Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class));
+ Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds.collectAsList());
NestedJavaBean obj3 = new NestedJavaBean();
obj3.setA(obj1);
List<NestedJavaBean> data2 = Arrays.asList(obj3);
- Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class));
+ Dataset<NestedJavaBean> ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class));
Assert.assertEquals(data2, ds2.collectAsList());
Row row1 = new GenericRow(new Object[]{
@@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable {
.add("d", createArrayType(StringType))
.add("e", createArrayType(StringType))
.add("f", createArrayType(LongType));
- Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema)
+ Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
.as(Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds3.collectAsList());
}
@@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable {
obj.setB(new Date(0));
obj.setC(java.math.BigDecimal.valueOf(1));
Dataset<SimpleJavaBean2> ds =
- context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
+ spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
ds.collect();
}
@@ -776,7 +773,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
SmallBean smallBean = new SmallBean();
@@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable {
{
Row row = new GenericRow(new Object[] { null });
- Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
ds.collect();
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 4a78dca7fe..2274912521 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -24,33 +24,30 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkContext;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.types.DataTypes;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
public class JavaUDFSuite implements Serializable {
- private transient JavaSparkContext sc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
@Before
public void setUp() {
- SparkContext _sc = new SparkContext("local[*]", "testing");
- sqlContext = new SQLContext(_sc);
- sc = new JavaSparkContext(_sc);
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
}
@After
public void tearDown() {
- sqlContext.sparkContext().stop();
- sqlContext = null;
- sc = null;
+ spark.stop();
+ spark = null;
}
@SuppressWarnings("unchecked")
@@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable {
// sqlContext.registerFunction(
// "stringLengthTest", (String str) -> str.length(), DataType.IntegerType);
- sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() {
+ spark.udf().register("stringLengthTest", new UDF1<String, Integer>() {
@Override
public Integer call(String str) {
return str.length();
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test')").head();
+ Row result = spark.sql("SELECT stringLengthTest('test')").head();
Assert.assertEquals(4, result.getInt(0));
}
@@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable {
// (String str1, String str2) -> str1.length() + str2.length,
// DataType.IntegerType);
- sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
+ spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() {
@Override
public Integer call(String str1, String str2) {
return str1.length() + str2.length();
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head();
+ Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
index 7863177093..059c2d9f2c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
@@ -26,36 +26,30 @@ import scala.Tuple2;
import org.junit.After;
import org.junit.Before;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.test.TestSparkSession;
/**
* Common test base shared across this and Java8DatasetAggregatorSuite.
*/
public class JavaDatasetAggregatorSuiteBase implements Serializable {
- protected transient JavaSparkContext jsc;
- protected transient TestSQLContext context;
+ private transient TestSparkSession spark;
@Before
public void setUp() {
// Trigger static initializer of TestData
- SparkContext sc = new SparkContext("local[*]", "testing");
- jsc = new JavaSparkContext(sc);
- context = new TestSQLContext(sc);
- context.loadTestData();
+ spark = new TestSparkSession();
+ spark.loadTestData();
}
@After
public void tearDown() {
- context.sparkContext().stop();
- context = null;
- jsc = null;
+ spark.stop();
+ spark = null;
}
protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
@@ -66,7 +60,7 @@ public class JavaDatasetAggregatorSuiteBase implements Serializable {
Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
List<Tuple2<String, Integer>> data =
Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
- Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+ Dataset<Tuple2<String, Integer>> ds = spark.createDataset(data, encoder);
return ds.groupByKey(
new MapFunction<Tuple2<String, Integer>, String>() {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e65158eb0..d0435e4d43 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -19,14 +19,16 @@ package test.org.apache.spark.sql.sources;
import java.io.File;
import java.io.IOException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
@@ -37,8 +39,8 @@ import org.apache.spark.util.Utils;
public class JavaSaveLoadSuite {
- private transient JavaSparkContext sc;
- private transient SQLContext sqlContext;
+ private transient SparkSession spark;
+ private transient JavaSparkContext jsc;
File path;
Dataset<Row> df;
@@ -52,9 +54,11 @@ public class JavaSaveLoadSuite {
@Before
public void setUp() throws IOException {
- SparkContext _sc = new SparkContext("local[*]", "testing");
- sqlContext = new SQLContext(_sc);
- sc = new JavaSparkContext(_sc);
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
path =
Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
@@ -66,16 +70,15 @@ public class JavaSaveLoadSuite {
for (int i = 0; i < 10; i++) {
jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
}
- JavaRDD<String> rdd = sc.parallelize(jsonObjects);
- df = sqlContext.read().json(rdd);
+ JavaRDD<String> rdd = jsc.parallelize(jsonObjects);
+ df = spark.read().json(rdd);
df.registerTempTable("jsonTable");
}
@After
public void tearDown() {
- sqlContext.sparkContext().stop();
- sqlContext = null;
- sc = null;
+ spark.stop();
+ spark = null;
}
@Test
@@ -83,7 +86,7 @@ public class JavaSaveLoadSuite {
Map<String, String> options = new HashMap<>();
options.put("path", path.toString());
df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
- Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
+ Dataset<Row> loadedDF = spark.read().format("json").options(options).load();
checkAnswer(loadedDF, df.collectAsList());
}
@@ -96,8 +99,8 @@ public class JavaSaveLoadSuite {
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
- Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+ Dataset<Row> loadedDF = spark.read().format("json").schema(schema).options(options).load();
- checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+ checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList());
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 5ef20267f8..800316cde7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -36,7 +36,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
import testImplicits._
def rddIdOf(tableName: String): Int = {
- val plan = sqlContext.table(tableName).queryExecution.sparkPlan
+ val plan = spark.table(tableName).queryExecution.sparkPlan
plan.collect {
case InMemoryTableScanExec(_, _, relation) =>
relation.cachedColumnBuffers.id
@@ -73,41 +73,41 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("cache temp table") {
testData.select('key).registerTempTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0)
- sqlContext.cacheTable("tempTable")
+ spark.catalog.cacheTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"))
- sqlContext.uncacheTable("tempTable")
+ spark.catalog.uncacheTable("tempTable")
}
test("unpersist an uncached table will not raise exception") {
- assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+ assert(None == spark.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+ assert(None == spark.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+ assert(None == spark.cacheManager.lookupCachedData(testData))
testData.persist()
- assert(None != sqlContext.cacheManager.lookupCachedData(testData))
+ assert(None != spark.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+ assert(None == spark.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == sqlContext.cacheManager.lookupCachedData(testData))
+ assert(None == spark.cacheManager.lookupCachedData(testData))
}
test("cache table as select") {
sql("CACHE TABLE tempTable AS SELECT key FROM testData")
assertCached(sql("SELECT COUNT(*) FROM tempTable"))
- sqlContext.uncacheTable("tempTable")
+ spark.catalog.uncacheTable("tempTable")
}
test("uncaching temp table") {
testData.select('key).registerTempTable("tempTable1")
testData.select('key).registerTempTable("tempTable2")
- sqlContext.cacheTable("tempTable1")
+ spark.catalog.cacheTable("tempTable1")
assertCached(sql("SELECT COUNT(*) FROM tempTable1"))
assertCached(sql("SELECT COUNT(*) FROM tempTable2"))
// Is this valid?
- sqlContext.uncacheTable("tempTable2")
+ spark.catalog.uncacheTable("tempTable2")
// Should this be cached?
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0)
@@ -117,101 +117,101 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
val data = "*" * 1000
sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
.registerTempTable("bigData")
- sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
- assert(sqlContext.table("bigData").count() === 200000L)
- sqlContext.table("bigData").unpersist(blocking = true)
+ spark.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
+ assert(spark.table("bigData").count() === 200000L)
+ spark.table("bigData").unpersist(blocking = true)
}
test("calling .cache() should use in-memory columnar caching") {
- sqlContext.table("testData").cache()
- assertCached(sqlContext.table("testData"))
- sqlContext.table("testData").unpersist(blocking = true)
+ spark.table("testData").cache()
+ assertCached(spark.table("testData"))
+ spark.table("testData").unpersist(blocking = true)
}
test("calling .unpersist() should drop in-memory columnar cache") {
- sqlContext.table("testData").cache()
- sqlContext.table("testData").count()
- sqlContext.table("testData").unpersist(blocking = true)
- assertCached(sqlContext.table("testData"), 0)
+ spark.table("testData").cache()
+ spark.table("testData").count()
+ spark.table("testData").unpersist(blocking = true)
+ assertCached(spark.table("testData"), 0)
}
test("isCached") {
- sqlContext.cacheTable("testData")
+ spark.catalog.cacheTable("testData")
- assertCached(sqlContext.table("testData"))
- assert(sqlContext.table("testData").queryExecution.withCachedData match {
+ assertCached(spark.table("testData"))
+ assert(spark.table("testData").queryExecution.withCachedData match {
case _: InMemoryRelation => true
case _ => false
})
- sqlContext.uncacheTable("testData")
- assert(!sqlContext.isCached("testData"))
- assert(sqlContext.table("testData").queryExecution.withCachedData match {
+ spark.catalog.uncacheTable("testData")
+ assert(!spark.catalog.isCached("testData"))
+ assert(spark.table("testData").queryExecution.withCachedData match {
case _: InMemoryRelation => false
case _ => true
})
}
test("SPARK-1669: cacheTable should be idempotent") {
- assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
+ assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
- sqlContext.cacheTable("testData")
- assertCached(sqlContext.table("testData"))
+ spark.catalog.cacheTable("testData")
+ assertCached(spark.table("testData"))
assertResult(1, "InMemoryRelation not found, testData should have been cached") {
- sqlContext.table("testData").queryExecution.withCachedData.collect {
+ spark.table("testData").queryExecution.withCachedData.collect {
case r: InMemoryRelation => r
}.size
}
- sqlContext.cacheTable("testData")
+ spark.catalog.cacheTable("testData")
assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
- sqlContext.table("testData").queryExecution.withCachedData.collect {
+ spark.table("testData").queryExecution.withCachedData.collect {
case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r
}.size
}
- sqlContext.uncacheTable("testData")
+ spark.catalog.uncacheTable("testData")
}
test("read from cached table and uncache") {
- sqlContext.cacheTable("testData")
- checkAnswer(sqlContext.table("testData"), testData.collect().toSeq)
- assertCached(sqlContext.table("testData"))
+ spark.catalog.cacheTable("testData")
+ checkAnswer(spark.table("testData"), testData.collect().toSeq)
+ assertCached(spark.table("testData"))
- sqlContext.uncacheTable("testData")
- checkAnswer(sqlContext.table("testData"), testData.collect().toSeq)
- assertCached(sqlContext.table("testData"), 0)
+ spark.catalog.uncacheTable("testData")
+ checkAnswer(spark.table("testData"), testData.collect().toSeq)
+ assertCached(spark.table("testData"), 0)
}
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
- sqlContext.uncacheTable("testData")
+ spark.catalog.uncacheTable("testData")
}
}
test("SELECT star from cached table") {
sql("SELECT * FROM testData").registerTempTable("selectStar")
- sqlContext.cacheTable("selectStar")
+ spark.catalog.cacheTable("selectStar")
checkAnswer(
sql("SELECT * FROM selectStar WHERE key = 1"),
Seq(Row(1, "1")))
- sqlContext.uncacheTable("selectStar")
+ spark.catalog.uncacheTable("selectStar")
}
test("Self-join cached") {
val unCachedAnswer =
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
- sqlContext.cacheTable("testData")
+ spark.catalog.cacheTable("testData")
checkAnswer(
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
unCachedAnswer.toSeq)
- sqlContext.uncacheTable("testData")
+ spark.catalog.uncacheTable("testData")
}
test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
sql("CACHE TABLE testData")
- assertCached(sqlContext.table("testData"))
+ assertCached(spark.table("testData"))
val rddId = rddIdOf("testData")
assert(
@@ -219,7 +219,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
"Eagerly cached in-memory table should have already been materialized")
sql("UNCACHE TABLE testData")
- assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached")
+ assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
@@ -228,14 +228,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("CACHE TABLE tableName AS SELECT * FROM anotherTable") {
sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- assertCached(sqlContext.table("testCacheTable"))
+ assertCached(spark.table("testCacheTable"))
val rddId = rddIdOf("testCacheTable")
assert(
isMaterialized(rddId),
"Eagerly cached in-memory table should have already been materialized")
- sqlContext.uncacheTable("testCacheTable")
+ spark.catalog.uncacheTable("testCacheTable")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -243,14 +243,14 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("CACHE TABLE tableName AS SELECT ...") {
sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10")
- assertCached(sqlContext.table("testCacheTable"))
+ assertCached(spark.table("testCacheTable"))
val rddId = rddIdOf("testCacheTable")
assert(
isMaterialized(rddId),
"Eagerly cached in-memory table should have already been materialized")
- sqlContext.uncacheTable("testCacheTable")
+ spark.catalog.uncacheTable("testCacheTable")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -258,7 +258,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("CACHE LAZY TABLE tableName") {
sql("CACHE LAZY TABLE testData")
- assertCached(sqlContext.table("testData"))
+ assertCached(spark.table("testData"))
val rddId = rddIdOf("testData")
assert(
@@ -270,7 +270,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
isMaterialized(rddId),
"Lazily cached in-memory table should have been materialized")
- sqlContext.uncacheTable("testData")
+ spark.catalog.uncacheTable("testData")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -278,7 +278,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("InMemoryRelation statistics") {
sql("CACHE TABLE testData")
- sqlContext.table("testData").queryExecution.withCachedData.collect {
+ spark.table("testData").queryExecution.withCachedData.collect {
case cached: InMemoryRelation =>
val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum
assert(cached.statistics.sizeInBytes === actualSizeInBytes)
@@ -287,62 +287,62 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("Drops temporary table") {
testData.select('key).registerTempTable("t1")
- sqlContext.table("t1")
- sqlContext.dropTempTable("t1")
- intercept[AnalysisException](sqlContext.table("t1"))
+ spark.table("t1")
+ spark.catalog.dropTempTable("t1")
+ intercept[AnalysisException](spark.table("t1"))
}
test("Drops cached temporary table") {
testData.select('key).registerTempTable("t1")
testData.select('key).registerTempTable("t2")
- sqlContext.cacheTable("t1")
+ spark.catalog.cacheTable("t1")
- assert(sqlContext.isCached("t1"))
- assert(sqlContext.isCached("t2"))
+ assert(spark.catalog.isCached("t1"))
+ assert(spark.catalog.isCached("t2"))
- sqlContext.dropTempTable("t1")
- intercept[AnalysisException](sqlContext.table("t1"))
- assert(!sqlContext.isCached("t2"))
+ spark.catalog.dropTempTable("t1")
+ intercept[AnalysisException](spark.table("t1"))
+ assert(!spark.catalog.isCached("t2"))
}
test("Clear all cache") {
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
- sqlContext.clearCache()
- assert(sqlContext.cacheManager.isEmpty)
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
+ spark.catalog.clearCache()
+ assert(spark.cacheManager.isEmpty)
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
sql("Clear CACHE")
- assert(sqlContext.cacheManager.isEmpty)
+ assert(spark.cacheManager.isEmpty)
}
test("Clear accumulators when uncacheTable to prevent memory leaking") {
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
- val accId1 = sqlContext.table("t1").queryExecution.withCachedData.collect {
+ val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.batchStats.id
}.head
- val accId2 = sqlContext.table("t1").queryExecution.withCachedData.collect {
+ val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.batchStats.id
}.head
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
assert(AccumulatorContext.get(accId1).isEmpty)
assert(AccumulatorContext.get(accId2).isEmpty)
@@ -351,7 +351,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") {
sparkContext.parallelize((1, 1) :: (2, 2) :: Nil)
.toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc")
- sqlContext.cacheTable("abc")
+ spark.catalog.cacheTable("abc")
val sparkPlan = sql(
"""select a.key, b.key, c.key from
@@ -374,15 +374,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
table3x.registerTempTable("testData3x")
sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
- sqlContext.cacheTable("orderedTable")
- assertCached(sqlContext.table("orderedTable"))
+ spark.catalog.cacheTable("orderedTable")
+ assertCached(spark.table("orderedTable"))
// Should not have an exchange as the query is already sorted on the group by key.
verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0)
checkAnswer(
sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
- sqlContext.uncacheTable("orderedTable")
- sqlContext.dropTempTable("orderedTable")
+ spark.catalog.uncacheTable("orderedTable")
+ spark.catalog.dropTempTable("orderedTable")
// Set up two tables distributed in the same way. Try this with the data distributed into
// different number of partitions.
@@ -390,8 +390,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
withTempTable("t1", "t2") {
testData.repartition(numPartitions, $"key").registerTempTable("t1")
testData2.repartition(numPartitions, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
// Joining them should result in no exchanges.
verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
@@ -403,8 +403,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
sql("SELECT count(*) FROM testData GROUP BY key"))
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
}
}
@@ -412,8 +412,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
withTempTable("t1", "t2") {
testData.repartition(6, $"key").registerTempTable("t1")
testData2.repartition(3, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 1)
@@ -421,16 +421,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
}
// One side of join is not partitioned in the desired way. Need to shuffle one side.
withTempTable("t1", "t2") {
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(6, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 1)
@@ -438,15 +438,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
}
withTempTable("t1", "t2") {
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(12, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 1)
@@ -454,8 +454,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
}
// One side of join is not partitioned in the desired way. Since the number of partitions of
@@ -464,30 +464,30 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
withTempTable("t1", "t2") {
testData.repartition(6, $"value").registerTempTable("t1")
testData2.repartition(3, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
verifyNumExchanges(query, 2)
checkAnswer(
query,
testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
}
// repartition's column ordering is different from group by column ordering.
// But they use the same set of columns.
withTempTable("t1") {
testData.repartition(6, $"value", $"key").registerTempTable("t1")
- sqlContext.cacheTable("t1")
+ spark.catalog.cacheTable("t1")
val query = sql("SELECT value, key from t1 group by key, value")
verifyNumExchanges(query, 0)
checkAnswer(
query,
testData.distinct().select($"value", $"key"))
- sqlContext.uncacheTable("t1")
+ spark.catalog.uncacheTable("t1")
}
// repartition's column ordering is different from join condition's column ordering.
@@ -499,8 +499,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
df1.repartition(6, $"value", $"key").registerTempTable("t1")
val df2 = testData2.select($"a", $"b".cast("string"))
df2.repartition(6, $"a", $"b").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ spark.catalog.cacheTable("t1")
+ spark.catalog.cacheTable("t2")
val query =
sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b")
@@ -509,8 +509,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
checkAnswer(
query,
df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b"))
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
+ spark.catalog.uncacheTable("t1")
+ spark.catalog.uncacheTable("t2")
}
}
}
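Throughout CachedTableSuite the caching calls move from SQLContext methods onto the session's catalog, while table lookups go through spark.table. A minimal, self-contained sketch of that call pattern, assuming a throwaway local session and an illustrative temp table name (neither is taken from the suite):

  import org.apache.spark.sql.SparkSession

  object CatalogCacheSketch {
    def main(args: Array[String]): Unit = {
      // Illustrative local session; the suites themselves get `spark` from SharedSQLContext.
      val spark = SparkSession.builder()
        .master("local[2]")
        .appName("catalog-cache-sketch")
        .getOrCreate()
      import spark.implicits._

      // "people" is a made-up temp table name for this sketch.
      Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("people")

      spark.catalog.cacheTable("people")         // was sqlContext.cacheTable("people")
      assert(spark.catalog.isCached("people"))   // was sqlContext.isCached("people")
      spark.table("people").count()              // was sqlContext.table("people"); materializes the cache
      spark.catalog.uncacheTable("people")       // was sqlContext.uncacheTable("people")
      spark.catalog.dropTempTable("people")      // was sqlContext.dropTempTable("people")

      spark.stop()
    }
  }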
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 19fe29a202..a5aecca13f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val booleanData = {
- sqlContext.createDataFrame(sparkContext.parallelize(
+ spark.createDataFrame(sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
@@ -287,7 +287,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("isNaN") {
- val testData = sqlContext.createDataFrame(sparkContext.parallelize(
+ val testData = spark.createDataFrame(sparkContext.parallelize(
Row(Double.NaN, Float.NaN) ::
Row(math.log(-1), math.log(-3).toFloat) ::
Row(null, null) ::
@@ -308,7 +308,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("nanvl") {
- val testData = sqlContext.createDataFrame(sparkContext.parallelize(
+ val testData = spark.createDataFrame(sparkContext.parallelize(
Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
StructField("c", DoubleType), StructField("d", DoubleType),
@@ -351,7 +351,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("=!=") {
- val nullData = sqlContext.createDataFrame(sparkContext.parallelize(
+ val nullData = spark.createDataFrame(sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -370,7 +370,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
nullData.filter($"a" <=> $"b"),
Row(1, 1) :: Row(null, null) :: Nil)
- val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize(
+ val nullData2 = spark.createDataFrame(sparkContext.parallelize(
Row("abc") ::
Row(null) ::
Row("xyz") :: Nil),
@@ -596,7 +596,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
withTempPath { dir =>
val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)
- val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name())
+ val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name())
.head.getString(0)
assert(answer.contains(dir.getCanonicalPath))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 63f4b759a0..8a99866a33 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -70,7 +70,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
)
- val decimalDataWithNulls = sqlContext.sparkContext.parallelize(
+ val decimalDataWithNulls = spark.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, null) ::
DecimalData(2, 1) ::
@@ -114,7 +114,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(null, null, 113000.0) :: Nil
)
- val df0 = sqlContext.sparkContext.parallelize(Seq(
+ val df0 = spark.sparkContext.parallelize(Seq(
Fact(20151123, 18, 35, "room1", 18.6),
Fact(20151123, 18, 35, "room2", 22.4),
Fact(20151123, 18, 36, "room1", 17.4),
@@ -207,12 +207,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
- sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false)
+ spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, false)
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(3), Row(3), Row(3))
)
- sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true)
+ spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true)
}
test("agg without groups") {
@@ -433,10 +433,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
- sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+ spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil)
checkAnswer(
- sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+ spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
}
}
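The aggregate suite now toggles SQL settings through the session's runtime config, addressing the entry by its key string instead of mutating SQLConf directly. A hedged sketch of that pattern, assuming an existing `spark`; the save-and-restore wrapper is illustrative rather than something the suite does:

  // "spark.sql.retainGroupColumns" is the key behind SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.
  val previous = spark.conf.get("spark.sql.retainGroupColumns", "true")
  spark.conf.set("spark.sql.retainGroupColumns", "false")
  try {
    // assertions that expect the grouping columns to be dropped would run here
  } finally {
    spark.conf.set("spark.sql.retainGroupColumns", previous) // restore so later tests are unaffected
  }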
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 0414fa1c91..031e66b57c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -154,7 +154,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
// SPARK-12275: no physical plan for BroadcastHint in some condition
withTempPath { path =>
df1.write.parquet(path.getCanonicalPath)
- val pf1 = sqlContext.read.parquet(path.getCanonicalPath)
+ val pf1 = spark.read.parquet(path.getCanonicalPath)
assert(df1.join(broadcast(pf1)).count() === 4)
}
}
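Reader access moves the same way, from sqlContext.read to spark.read. A small Parquet round-trip sketch, assuming an existing `spark`; the temp directory and column names are illustrative:

  import java.nio.file.Files

  import spark.implicits._

  // Write a tiny DataFrame as Parquet, then read it back through the session's DataFrameReader.
  val dir = Files.createTempDirectory("parquet-sketch").toFile.getCanonicalPath
  Seq((1, "a"), (2, "b")).toDF("id", "name").write.mode("overwrite").parquet(dir)
  val roundTripped = spark.read.parquet(dir)   // was sqlContext.read.parquet(dir)
  assert(roundTripped.count() == 2)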
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index c6d67519b0..fa8fa06907 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -81,11 +81,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
}
test("pivot max values enforced") {
- sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
+ spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1)
intercept[AnalysisException](
courseSales.groupBy("year").pivot("course")
)
- sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
+ spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
}
@@ -104,7 +104,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
// pivot with extra columns to trigger optimization
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
.agg(sum($"earnings"))
- val queryExecution = sqlContext.executePlan(df.queryExecution.logical)
+ val queryExecution = spark.executePlan(df.queryExecution.logical)
assert(queryExecution.simpleString.contains("pivotfirst"))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0ea7727e45..ab7733b239 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -236,7 +236,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
test("sampleBy") {
- val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
+ val df = spark.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),
@@ -247,7 +247,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
// `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in
// `CountMinSketchSuite` in project spark-sketch.
test("countMinSketch") {
- val df = sqlContext.range(1000)
+ val df = spark.range(1000)
val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
assert(sketch1.totalCount() === 1000)
@@ -279,7 +279,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
// This test only verifies some basic requirements, more correctness tests can be found in
// `BloomFilterSuite` in project spark-sketch.
test("Bloom filter") {
- val df = sqlContext.range(1000)
+ val df = spark.range(1000)
val filter1 = df.stat.bloomFilter("id", 1000, 0.03)
assert(filter1.expectedFpp() - 0.03 < 1e-3)
@@ -304,7 +304,7 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin
// Turn on this test if you want to test the performance of approximate quantiles.
ignore("computing quantiles should not take much longer than describe()") {
- val df = sqlContext.range(5000000L).toDF("col1").cache()
+ val df = spark.range(5000000L).toDF("col1").cache()
def seconds(f: => Any): Double = {
// Do some warmup
logDebug("warmup...")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 80a93ee6d4..f77403c13e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -99,8 +99,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0))))
val schema2 = StructType(Array(StructField("label", IntegerType, false),
StructField("point", new ExamplePointUDT(), false)))
- val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
- val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+ val df1 = spark.createDataFrame(rowRDD1, schema1)
+ val df2 = spark.createDataFrame(rowRDD2, schema2)
checkAnswer(
df1.union(df2).orderBy("label"),
@@ -109,8 +109,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("empty data frame") {
- assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
- assert(sqlContext.emptyDataFrame.count() === 0)
+ assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String])
+ assert(spark.emptyDataFrame.count() === 0)
}
test("head and take") {
@@ -369,7 +369,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
checkAnswer(
- sqlContext.range(2).toDF().limit(2147483638),
+ spark.range(2).toDF().limit(2147483638),
Row(0) :: Row(1) :: Nil
)
}
@@ -672,12 +672,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val parquetDir = new File(dir, "parquet").getCanonicalPath
df.write.parquet(parquetDir)
- val parquetDF = sqlContext.read.parquet(parquetDir)
+ val parquetDF = spark.read.parquet(parquetDir)
assert(parquetDF.inputFiles.nonEmpty)
val jsonDir = new File(dir, "json").getCanonicalPath
df.write.json(jsonDir)
- val jsonDF = sqlContext.read.json(jsonDir)
+ val jsonDF = spark.read.json(jsonDir)
assert(parquetDF.inputFiles.nonEmpty)
val unioned = jsonDF.union(parquetDF).inputFiles.sorted
@@ -801,7 +801,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
- val df = sqlContext.createDataFrame(rowRDD, schema)
+ val df = spark.createDataFrame(rowRDD, schema)
df.rdd.collect()
}
@@ -818,14 +818,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
- val df = sqlContext.read.json(sparkContext.makeRDD(
+ val df = spark.read.json(sparkContext.makeRDD(
"""{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df.select(df("`a.b`.c.`d..e`.`f`")),
Row(1)
)
- val df2 = sqlContext.read.json(sparkContext.makeRDD(
+ val df2 = spark.read.json(sparkContext.makeRDD(
"""{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df2.select(df2("`a b`.c.d e.f")),
@@ -881,53 +881,53 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-7150 range api") {
// numSlice is greater than length
- val res1 = sqlContext.range(0, 10, 1, 15).select("id")
+ val res1 = spark.range(0, 10, 1, 15).select("id")
assert(res1.count == 10)
assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
- val res2 = sqlContext.range(3, 15, 3, 2).select("id")
+ val res2 = spark.range(3, 15, 3, 2).select("id")
assert(res2.count == 4)
assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
- val res3 = sqlContext.range(1, -2).select("id")
+ val res3 = spark.range(1, -2).select("id")
assert(res3.count == 0)
// start is positive, end is negative, step is negative
- val res4 = sqlContext.range(1, -2, -2, 6).select("id")
+ val res4 = spark.range(1, -2, -2, 6).select("id")
assert(res4.count == 2)
assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
// start, end, step are negative
- val res5 = sqlContext.range(-3, -8, -2, 1).select("id")
+ val res5 = spark.range(-3, -8, -2, 1).select("id")
assert(res5.count == 3)
assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
// start, end are negative, step is positive
- val res6 = sqlContext.range(-8, -4, 2, 1).select("id")
+ val res6 = spark.range(-8, -4, 2, 1).select("id")
assert(res6.count == 2)
assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
- val res7 = sqlContext.range(-10, -9, -20, 1).select("id")
+ val res7 = spark.range(-10, -9, -20, 1).select("id")
assert(res7.count == 0)
- val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
assert(res8.count == 3)
assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
- val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
assert(res9.count == 2)
assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
// only end provided as argument
- val res10 = sqlContext.range(10).select("id")
+ val res10 = spark.range(10).select("id")
assert(res10.count == 10)
assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
- val res11 = sqlContext.range(-1).select("id")
+ val res11 = spark.range(-1).select("id")
assert(res11.count == 0)
// using the default slice number
- val res12 = sqlContext.range(3, 15, 3).select("id")
+ val res12 = spark.range(3, 15, 3).select("id")
assert(res12.count == 4)
assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
}
@@ -993,13 +993,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// pass case: parquet table (HadoopFsRelation)
df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
- val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath)
+ val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath)
pdf.registerTempTable("parquet_base")
insertion.write.insertInto("parquet_base")
// pass case: json table (InsertableRelation)
df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
- val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath)
+ val jdf = spark.read.json(tempJsonFile.getCanonicalPath)
jdf.registerTempTable("json_base")
insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")
@@ -1019,7 +1019,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
// error case: insert into an OneRowRelation
- Dataset.ofRows(sqlContext.sparkSession, OneRowRelation).registerTempTable("one_row")
+ Dataset.ofRows(spark, OneRowRelation).registerTempTable("one_row")
val e3 = intercept[AnalysisException] {
insertion.write.insertInto("one_row")
}
@@ -1062,7 +1062,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-9323: DataFrame.orderBy should support nested column name") {
- val df = sqlContext.read.json(sparkContext.makeRDD(
+ val df = spark.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
checkAnswer(df.orderBy("a.b"), Row(Row(1)))
}
@@ -1091,10 +1091,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val dir2 = new File(dir, "dir2").getCanonicalPath
df2.write.format("json").save(dir2)
- checkAnswer(sqlContext.read.format("json").load(dir1, dir2),
+ checkAnswer(spark.read.format("json").load(dir1, dir2),
Row(1, 22) :: Row(2, 23) :: Nil)
- checkAnswer(sqlContext.read.format("json").load(dir1),
+ checkAnswer(spark.read.format("json").load(dir1),
Row(1, 22) :: Nil)
}
}
@@ -1116,7 +1116,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
- val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ val input = spark.read.json(spark.sparkContext.makeRDD(
(1 to 10).map(i => s"""{"id": $i}""")))
val df = input.select($"id", rand(0).as('r))
@@ -1185,7 +1185,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
withTempPath { path =>
Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath)
- val df = sqlContext.read.parquet(path.getAbsolutePath)
+ val df = spark.read.parquet(path.getAbsolutePath)
checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a"))
}
}
@@ -1244,7 +1244,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
verifyExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key").count())
- val data = sqlContext.sparkContext.parallelize(
+ val data = spark.sparkContext.parallelize(
(1 to 100).map(i => TestData2(i % 10, i))).toDF()
// Distribute and order by.
@@ -1308,7 +1308,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
withTempPath { path =>
val p = path.getAbsolutePath
Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
- checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012))
+ checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012))
}
}
}
@@ -1317,7 +1317,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1)))
- val df = sqlContext.createDataFrame(
+ val df = spark.createDataFrame(
rdd,
new StructType().add("f1", IntegerType).add("f2", IntegerType),
needsConversion = false).select($"F1", $"f2".as("f2"))
@@ -1344,7 +1344,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
- sqlContext.udf.register("boxedUDF",
+ spark.udf.register("boxedUDF",
(i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
@@ -1393,7 +1393,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("reuse exchange") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") {
- val df = sqlContext.range(100).toDF()
+ val df = spark.range(100).toDF()
val join = df.join(df, "id")
val plan = join.queryExecution.executedPlan
checkAnswer(join, df)
@@ -1415,14 +1415,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("sameResult() on aggregate") {
- val df = sqlContext.range(100)
+ val df = spark.range(100)
val agg1 = df.groupBy().count()
val agg2 = df.groupBy().count()
// two aggregates with different ExprId within them should have same result
assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan))
val agg3 = df.groupBy().sum()
assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan))
- val df2 = sqlContext.range(101)
+ val df2 = spark.range(101)
val agg4 = df2.groupBy().count()
assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan))
}
@@ -1454,24 +1454,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("assertAnalyzed shouldn't replace original stack trace") {
val e = intercept[AnalysisException] {
- sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b)
+ spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b)
}
assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName)
}
test("SPARK-13774: Check error message for non existent path without globbed paths") {
- val e = intercept[AnalysisException] (sqlContext.read.format("csv").
+ val e = intercept[AnalysisException] (spark.read.format("csv").
load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage()
assert(e.startsWith("Path does not exist"))
}
test("SPARK-13774: Check error message for not existent globbed paths") {
- val e = intercept[AnalysisException] (sqlContext.read.format("text").
+ val e = intercept[AnalysisException] (spark.read.format("text").
load( "/xyz/*")).getMessage()
assert(e.startsWith("Path does not exist"))
- val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd).
+ val e1 = intercept[AnalysisException] (spark.read.json("/mnt/*/*-xyz.json").rdd).
getMessage()
assert(e1.startsWith("Path does not exist"))
}
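Explicit-schema construction follows the same substitution: the RDD[Row]-plus-StructType overload of createDataFrame now hangs off the session, as do range, read.json and udf.register. A hedged sketch of the createDataFrame form, assuming an existing `spark`; the rows and field names are illustrative:

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

  val rowRDD = spark.sparkContext.parallelize(Seq(Row(1, "a"), Row(2, "b")))
  val schema = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("name", StringType, nullable = true)))
  val df = spark.createDataFrame(rowRDD, schema)   // was sqlContext.createDataFrame(rowRDD, schema)
  assert(df.where(df("id") === 1).count() == 1)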
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 06584ec21e..a957d5ba25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -249,14 +249,14 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
try {
f(tableName)
} finally {
- sqlContext.dropTempTable(tableName)
+ spark.catalog.dropTempTable(tableName)
}
}
test("time window in SQL with single string expression") {
withTempTable { table =>
checkAnswer(
- sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""")
+ spark.sql(s"""select window(time, "10 seconds"), value from $table""")
.select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
Seq(
Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
@@ -270,7 +270,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
test("time window in SQL with with two expressions") {
withTempTable { table =>
checkAnswer(
- sqlContext.sql(
+ spark.sql(
s"""select window(time, "10 seconds", 10000000), value from $table""")
.select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
Seq(
@@ -285,7 +285,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
test("time window in SQL with with three expressions") {
withTempTable { table =>
checkAnswer(
- sqlContext.sql(
+ spark.sql(
s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""")
.select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index 68e99d6a6b..fe6ba83b4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -48,7 +48,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
.add("b3", FloatType)
.add("b4", DoubleType))
- val df = sqlContext.createDataFrame(data, schema)
+ val df = spark.createDataFrame(data, schema)
assert(df.select("b").first() === Row(struct))
}
@@ -70,7 +70,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
.add("b5b", StringType))
.add("b6", StringType))
- val df = sqlContext.createDataFrame(data, schema)
+ val df = spark.createDataFrame(data, schema)
assert(df.select("b").first() === Row(outerStruct))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index ae9fb80c68..d8e241c62f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
@@ -31,14 +31,14 @@ object DatasetBenchmark {
case class Data(l: Long, s: String)
- def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
- import sqlContext.implicits._
+ def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+ import spark.implicits._
- val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
val benchmark = new Benchmark("back-to-back map", numRows)
val func = (d: Data) => Data(d.l + 1, d.s)
- val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
@@ -72,17 +72,17 @@ object DatasetBenchmark {
benchmark
}
- def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
- import sqlContext.implicits._
+ def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+ import spark.implicits._
- val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
val benchmark = new Benchmark("back-to-back filter", numRows)
val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
val funcs = 0.until(numChains).map { i =>
(d: Data) => func(d, i)
}
- val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
@@ -130,13 +130,13 @@ object DatasetBenchmark {
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
- def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = {
- import sqlContext.implicits._
+ def aggregate(spark: SparkSession, numRows: Long): Benchmark = {
+ import spark.implicits._
- val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
val benchmark = new Benchmark("aggregate", numRows)
- val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
benchmark.addCase("RDD sum") { iter =>
rdd.aggregate(0L)(_ + _.l, _ + _)
}
@@ -157,15 +157,17 @@ object DatasetBenchmark {
}
def main(args: Array[String]): Unit = {
- val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
- val sqlContext = new SQLContext(sparkContext)
+ val spark = SparkSession.builder
+ .master("local[*]")
+ .appName("Dataset benchmark")
+ .getOrCreate()
val numRows = 100000000
val numChains = 10
- val benchmark = backToBackMap(sqlContext, numRows, numChains)
- val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
- val benchmark3 = aggregate(sqlContext, numRows)
+ val benchmark = backToBackMap(spark, numRows, numChains)
+ val benchmark2 = backToBackFilter(spark, numRows, numChains)
+ val benchmark3 = aggregate(spark, numRows)
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 942cc09b6d..8c0906b746 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -39,7 +39,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
// Drop the cache.
cached.unpersist()
- assert(!sqlContext.isCached(cached), "The Dataset should not be cached.")
+ assert(spark.cacheManager.lookupCachedData(cached).isEmpty, "The Dataset should not be cached.")
}
test("persist and then rebind right encoder when join 2 datasets") {
@@ -56,9 +56,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
assertCached(joined, 2)
ds1.unpersist()
- assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.")
+ assert(spark.cacheManager.lookupCachedData(ds1).isEmpty,
+ "The Dataset ds1 should not be cached.")
ds2.unpersist()
- assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.")
+ assert(spark.cacheManager.lookupCachedData(ds2).isEmpty,
+ "The Dataset ds2 should not be cached.")
}
test("persist and then groupBy columns asKey, map") {
@@ -73,8 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
assertCached(agged.filter(_._1 == "b"))
ds.unpersist()
- assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.")
+ assert(spark.cacheManager.lookupCachedData(ds).isEmpty, "The Dataset ds should not be cached.")
agged.unpersist()
- assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.")
+ assert(spark.cacheManager.lookupCachedData(agged).isEmpty,
+ "The Dataset agged should not be cached.")
}
}
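The cache-related assertions above reach into the session's CacheManager, an internal hook these suites can see because they live under org.apache.spark.sql. The user-facing half of the pattern is plain Dataset persist/unpersist; a minimal sketch, assuming an existing `spark`:

  val ds = spark.range(0, 100)   // Dataset[java.lang.Long]
  ds.persist()
  ds.count()       // the first action materializes the in-memory cache
  ds.unpersist()   // drops the cached data; the suite then expects lookupCachedData(ds) to be empty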
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3cb4e52c6d..3c8c862c22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -46,12 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
test("range") {
- assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55)
- assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
- assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55)
- assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
- assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55)
- assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
+ assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55)
+ assert(spark.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
+ assert(spark.range(0, 10).map(_ + 1).reduce(_ + _) == 55)
+ assert(spark.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
+ assert(spark.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55)
+ assert(spark.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55)
}
test("SPARK-12404: Datatype Helper Serializability") {
@@ -472,7 +472,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-14696: implicit encoders for boxed types") {
- assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L)
+ assert(spark.range(1).map { i => i : java.lang.Long }.head == 0L)
}
test("SPARK-11894: Incorrect results are returned when using null") {
@@ -510,8 +510,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
))
def buildDataset(rows: Row*): Dataset[NestedStruct] = {
- val rowRDD = sqlContext.sparkContext.parallelize(rows)
- sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+ val rowRDD = spark.sparkContext.parallelize(rows)
+ spark.createDataFrame(rowRDD, schema).as[NestedStruct]
}
checkDataset(
@@ -626,7 +626,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-14554: Dataset.map may generate wrong java code for wide table") {
- val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
+ val wideDF = spark.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
// Make sure the generated code for this plan can compile and execute.
checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*)
}
@@ -654,7 +654,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
dataset.join(actual, dataset("user") === actual("id")).collect()
}
- test("SPARK-15097: implicits on dataset's sqlContext can be imported") {
+ test("SPARK-15097: implicits on dataset's spark can be imported") {
val dataset = Seq(1, 2, 3).toDS()
checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
}
@@ -735,10 +735,10 @@ object JavaData {
def apply(a: Int): JavaData = new JavaData(a)
}
-/** Used to test importing dataset.sqlContext.implicits._ */
+/** Used to test importing dataset.sparkSession.implicits._ */
object DatasetTransform {
def addOne(ds: Dataset[Int]): Dataset[Int] = {
- import ds.sqlContext.implicits._
+ import ds.sparkSession.implicits._
ds.map(_ + 1)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index b1987c6908..a41b465548 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -51,7 +51,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
test("insert an extraStrategy") {
try {
- sqlContext.experimental.extraStrategies = TestStrategy :: Nil
+ spark.experimental.extraStrategies = TestStrategy :: Nil
val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
checkAnswer(
@@ -62,7 +62,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
df.select("a", "b"),
Row("so slow", 1))
} finally {
- sqlContext.experimental.extraStrategies = Nil
+ spark.experimental.extraStrategies = Nil
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8cbad04e23..da567db5ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
- val planned = sqlContext.sessionState.planner.JoinSelection(join)
+ val planned = spark.sessionState.planner.JoinSelection(join)
assert(planned.size === 1)
}
@@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("join operator selection") {
- sqlContext.cacheManager.clearCache()
+ spark.cacheManager.clearCache()
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
Seq(
@@ -112,7 +112,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
// }
test("broadcasted hash join operator selection") {
- sqlContext.cacheManager.clearCache()
+ spark.cacheManager.clearCache()
sql("CACHE TABLE testData")
Seq(
("SELECT * FROM testData join testData2 ON key = a",
@@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("broadcasted hash outer join operator selection") {
- sqlContext.cacheManager.clearCache()
+ spark.cacheManager.clearCache()
sql("CACHE TABLE testData")
sql("CACHE TABLE testData2")
Seq(
@@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
- val planned = sqlContext.sessionState.planner.JoinSelection(join)
+ val planned = spark.sessionState.planner.JoinSelection(join)
assert(planned.size === 1)
}
@@ -435,7 +435,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("broadcasted existence join operator selection") {
- sqlContext.cacheManager.clearCache()
+ spark.cacheManager.clearCache()
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
@@ -461,17 +461,17 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("cross join with broadcast") {
sql("CACHE TABLE testData")
- val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData"))
+ val sizeInByteOfTestData = statisticSizeInByte(spark.table("testData"))
// set the broadcast threshold greater than the statistics of the cached table testData
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
- assert(statisticSizeInByte(sqlContext.table("testData2")) >
- sqlContext.conf.autoBroadcastJoinThreshold)
+ assert(statisticSizeInByte(spark.table("testData2")) >
+ spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
- assert(statisticSizeInByte(sqlContext.table("testData")) <
- sqlContext.conf.autoBroadcastJoinThreshold)
+ assert(statisticSizeInByte(spark.table("testData")) <
+ spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 9f6c86a575..c88dfe5f24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -33,36 +33,36 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
}
after {
- sqlContext.sessionState.catalog.dropTable(
+ spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
}
test("get all tables") {
checkAnswer(
- sqlContext.tables().filter("tableName = 'listtablessuitetable'"),
+ spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
sql("SHOW tables").filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
- sqlContext.sessionState.catalog.dropTable(
+ spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
- assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+ assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("getting all tables with a database name has no impact on returned table names") {
checkAnswer(
- sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"),
+ spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
sql("show TABLES in default").filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
- sqlContext.sessionState.catalog.dropTable(
+ spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
- assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
+ assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("query the returned DataFrame of tables") {
@@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
- Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach {
+ Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
@@ -81,9 +81,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
Row(true, "listtablessuitetable")
)
checkAnswer(
- sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
- sqlContext.dropTempTable("tables")
+ spark.catalog.dropTempTable("tables")
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
new file mode 100644
index 0000000000..1732977ee5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.Suite
+
+/** Manages a local `spark` {@link SparkSession} variable, correctly stopping it after each test. */
+trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite =>
+
+ @transient var spark: SparkSession = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
+ }
+
+ override def afterEach() {
+ try {
+ resetSparkContext()
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ def resetSparkContext(): Unit = {
+ LocalSparkSession.stop(spark)
+ spark = null
+ }
+
+}
+
+object LocalSparkSession {
+ def stop(spark: SparkSession) {
+ if (spark != null) {
+ spark.stop()
+ }
+ // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+
+ /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
+ def withSparkSession[T](sc: SparkSession)(f: SparkSession => T): T = {
+ try {
+ f(sc)
+ } finally {
+ stop(sc)
+ }
+ }
+
+}
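A hedged usage sketch for the new LocalSparkSession helper, assuming a ScalaTest FunSuite; the suite name, master setting and assertion are illustrative:

  import org.apache.spark.sql.{LocalSparkSession, SparkSession}
  import org.scalatest.FunSuite

  class LocalSparkSessionUsageSuite extends FunSuite with LocalSparkSession {
    test("count a small range") {
      // Assign the managed `spark` variable; afterEach() stops it and clears spark.driver.port.
      spark = SparkSession.builder()
        .master("local[2]")
        .appName("local-session-usage")
        .getOrCreate()
      assert(spark.range(0, 10).count() == 10)
    }
  }

Outside a suite, the companion object's withSparkSession runs a block and guarantees stop():

  LocalSparkSession.withSparkSession(
    SparkSession.builder().master("local[2]").appName("with-session").getOrCreate()) { spark =>
    spark.range(5).count()
  }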
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index df8b3b7d87..a1a9b66c1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.types.ObjectType
abstract class QueryTest extends PlanTest {
- protected def sqlContext: SQLContext
+ protected def spark: SparkSession
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
@@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest {
expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
- sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
+ spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
checkDecoding(ds, expectedAnswer: _*)
}
@@ -267,7 +267,7 @@ abstract class QueryTest extends PlanTest {
val jsonBackPlan = try {
- TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
+ TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext)
} catch {
case NonFatal(e) =>
fail(
@@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest {
def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
case l: LogicalRDD =>
val origin = logicalRDDs.pop()
- LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession)
+ LogicalRDD(l.output, origin.rdd)(spark)
case l: LocalRelation =>
val origin = localRelations.pop()
l.copy(data = origin.data)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1ff288cd19..e401abef29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -57,7 +57,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("show functions") {
def getFunctions(pattern: String): Seq[Row] = {
- StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern)
+ StringUtils.filterPattern(spark.sessionState.functionRegistry.listFunction(), pattern)
.map(Row(_))
}
checkAnswer(sql("SHOW functions"), getFunctions("*"))
@@ -88,7 +88,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-14415: All functions should have own descriptions") {
- for (f <- sqlContext.sessionState.functionRegistry.listFunction()) {
+ for (f <- spark.sessionState.functionRegistry.listFunction()) {
if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) {
checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.")
}
@@ -102,7 +102,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
(43, 81, 24)
).toDF("a", "b", "c").registerTempTable("cachedData")
- sqlContext.cacheTable("cachedData")
+ spark.catalog.cacheTable("cachedData")
checkAnswer(
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
Row(0) :: Row(81) :: Nil)
@@ -193,7 +193,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("grouping on nested fields") {
- sqlContext.read.json(sparkContext.parallelize(
+ spark.read.json(sparkContext.parallelize(
"""{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
.registerTempTable("rows")
@@ -211,7 +211,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-6201 IN type conversion") {
- sqlContext.read.json(
+ spark.read.json(
sparkContext.parallelize(
Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
.registerTempTable("d")
@@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-11226 Skip empty line in json file") {
- sqlContext.read.json(
+ spark.read.json(
sparkContext.parallelize(
Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "")))
.registerTempTable("d")
@@ -258,9 +258,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("aggregation with codegen") {
// Prepare a table on which we can group some rows.
- sqlContext.table("testData")
- .union(sqlContext.table("testData"))
- .union(sqlContext.table("testData"))
+ spark.table("testData")
+ .union(spark.table("testData"))
+ .union(spark.table("testData"))
.registerTempTable("testData3x")
try {
@@ -333,7 +333,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
"SELECT sum('a'), avg('a'), count(null) FROM testData",
Row(null, null, 0) :: Nil)
} finally {
- sqlContext.dropTempTable("testData3x")
+ spark.catalog.dropTempTable("testData3x")
}
}
@@ -1041,7 +1041,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SET commands semantics using sql()") {
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
val testKey = "test.key.0"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
@@ -1082,17 +1082,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql(s"SET $nonexistentKey"),
Row(nonexistentKey, "<undefined>")
)
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
}
test("SET commands with illegal or inappropriate argument") {
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
// Setting mapred.reduce.tasks to a negative value to automatically determine
// the number of reducers is not supported
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
}
test("apply schema") {
@@ -1110,7 +1110,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
+ val df1 = spark.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -1140,7 +1140,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+ val df2 = spark.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -1165,7 +1165,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = sqlContext.createDataFrame(rowRDD3, schema2)
+ val df3 = spark.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -1210,7 +1210,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta)
+ val personWithMeta = spark.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
@@ -1226,7 +1226,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-3371 Renaming a function expression with group by gives error") {
- sqlContext.udf.register("len", (s: String) => s.length)
+ spark.udf.register("len", (s: String) => s.length)
checkAnswer(
sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"),
Row(1))
@@ -1409,7 +1409,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-3483 Special chars in column names") {
val data = sparkContext.parallelize(
Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
- sqlContext.read.json(data).registerTempTable("records")
+ spark.read.json(data).registerTempTable("records")
sql("SELECT `key?number1`, `key.number2` FROM records")
}
@@ -1450,15 +1450,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-4322 Grouping field with struct field as sub expression") {
- sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
+ spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
.registerTempTable("data")
checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1))
- sqlContext.dropTempTable("data")
+ spark.catalog.dropTempTable("data")
- sqlContext.read.json(
+ spark.read.json(
sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2))
- sqlContext.dropTempTable("data")
+ spark.catalog.dropTempTable("data")
}
test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") {
@@ -1504,7 +1504,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-6145: ORDER BY test for nested fields") {
- sqlContext.read.json(sparkContext.makeRDD(
+ spark.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
.registerTempTable("nestedOrder")
@@ -1517,14 +1517,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-6145: special cases") {
- sqlContext.read.json(sparkContext.makeRDD(
+ spark.read.json(sparkContext.makeRDD(
"""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
- sqlContext.read.json(sparkContext.makeRDD(
+ spark.read.json(sparkContext.makeRDD(
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
.registerTempTable("t")
@@ -1628,7 +1628,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-7067: order by queries for complex ExtractValue chain") {
withTempTable("t") {
- sqlContext.read.json(sparkContext.makeRDD(
+ spark.read.json(sparkContext.makeRDD(
"""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
}
@@ -1776,7 +1776,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
// We don't support creating a temporary table while specifying a database
intercept[AnalysisException] {
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE db.t
|USING parquet
@@ -1787,7 +1787,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}.getMessage
// If you use backticks to quote the name then it's OK.
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE `db.t`
|USING parquet
@@ -1795,7 +1795,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
| path '$path'
|)
""".stripMargin)
- checkAnswer(sqlContext.table("`db.t`"), df)
+ checkAnswer(spark.table("`db.t`"), df)
}
}
@@ -1818,7 +1818,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("run sql directly on files") {
- val df = sqlContext.range(100).toDF()
+ val df = spark.range(100).toDF()
withTempPath(f => {
df.write.json(f.getCanonicalPath)
checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"),
@@ -1880,7 +1880,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-11303: filter should not be pushed down into sample") {
- val df = sqlContext.range(100)
+ val df = spark.range(100)
List(true, false).foreach { withReplacement =>
val sampled = df.sample(withReplacement, 0.1, 1)
val sampledOdd = sampled.filter("id % 2 != 0")
@@ -2059,7 +2059,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
// Identity udf that tracks the number of times it is called.
val countAcc = sparkContext.accumulator(0, "CallCount")
- sqlContext.udf.register("testUdf", (x: Int) => {
+ spark.udf.register("testUdf", (x: Int) => {
countAcc.++=(1)
x
})
@@ -2093,9 +2093,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1)
// Try disabling it via configuration.
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+ spark.conf.set("spark.sql.subexpressionElimination.enabled", "false")
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+ spark.conf.set("spark.sql.subexpressionElimination.enabled", "true")
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
}
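Read together, the SQLQuerySuite hunks above apply one consistent substitution; a condensed summary of the mapping (an illustrative note, not part of the diff):

// Catalog operations move onto spark.catalog:
//   sqlContext.cacheTable("t")     ->  spark.catalog.cacheTable("t")
//   sqlContext.dropTempTable("t")  ->  spark.catalog.dropTempTable("t")
// Runtime configuration moves onto spark.conf, with spark.wrapped.conf used
// where the underlying SQLConf is still needed (e.g. conf.clear()):
//   sqlContext.setConf(k, v)       ->  spark.conf.set(k, v)
// The remaining calls are direct renames on the session:
//   sqlContext.read / range / table / udf.register / createDataFrame / sql
//     ->  spark.read / spark.range / spark.table / spark.udf.register / ...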
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index ddab918629..b489b74fec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext
class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
- val _sqlContext = new SQLContext(sparkContext)
- new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
+ val spark = SparkSession.builder.getOrCreate()
+ new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 6fb1aca769..1ab562f873 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -290,7 +290,7 @@ trait StreamTest extends QueryTest with Timeouts {
verify(currentStream == null, "stream already running")
lastStream = currentStream
currentStream =
- sqlContext
+ spark
.streams
.startQuery(
StreamExecution.nextName,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 6809f26968..c7b95c2683 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -281,7 +281,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}
test("number format function") {
- val df = sqlContext.range(1)
+ val df = spark.range(1)
checkAnswer(
df.select(format_number(lit(5L), 4)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ec950332c5..427f24a9f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("built-in fixed arity expressions") {
- val df = sqlContext.emptyDataFrame
+ val df = spark.emptyDataFrame
df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
}
@@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext {
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
df.registerTempTable("tmp_table")
checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
- sqlContext.dropTempTable("tmp_table")
+ spark.catalog.dropTempTable("tmp_table")
}
test("SPARK-8005 input_file_name") {
withTempPath { dir =>
val data = sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
+ spark.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
val answer = sql("select input_file_name() from test_table").head().getString(0)
assert(answer.contains(dir.getCanonicalPath))
assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
- sqlContext.dropTempTable("test_table")
+ spark.catalog.dropTempTable("test_table")
}
}
test("error reporting for incorrect number of arguments") {
- val df = sqlContext.emptyDataFrame
+ val df = spark.emptyDataFrame
val e = intercept[AnalysisException] {
df.selectExpr("substr('abcd', 2, 3, 4)")
}
@@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("error reporting for undefined functions") {
- val df = sqlContext.emptyDataFrame
+ val df = spark.emptyDataFrame
val e = intercept[AnalysisException] {
df.selectExpr("a_function_that_does_not_exist()")
}
@@ -88,22 +88,22 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("Simple UDF") {
- sqlContext.udf.register("strLenScala", (_: String).length)
+ spark.udf.register("strLenScala", (_: String).length)
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
- sqlContext.udf.register("random0", () => { Math.random()})
+ spark.udf.register("random0", () => { Math.random()})
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
- sqlContext.udf.register("strLenScala", (_: String).length + (_: Int))
+ spark.udf.register("strLenScala", (_: String).length + (_: Int))
assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("UDF in a WHERE") {
- sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 })
+ spark.udf.register("oneArgFilter", (n: Int) => { n > 80 })
val df = sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
@@ -115,7 +115,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("UDF in a HAVING") {
- sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 })
+ spark.udf.register("havingFilter", (n: Long) => { n > 5 })
val df = Seq(("red", 1), ("red", 2), ("blue", 10),
("green", 100), ("green", 200)).toDF("g", "v")
@@ -134,7 +134,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("UDF in a GROUP BY") {
- sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
+ spark.udf.register("groupFunction", (n: Int) => { n > 10 })
val df = Seq(("red", 1), ("red", 2), ("blue", 10),
("green", 100), ("green", 200)).toDF("g", "v")
@@ -151,10 +151,10 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("UDFs everywhere") {
- sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
- sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 })
- sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 })
- sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 })
+ spark.udf.register("groupFunction", (n: Int) => { n > 10 })
+ spark.udf.register("havingFilter", (n: Long) => { n > 2000 })
+ spark.udf.register("whereFilter", (n: Int) => { n < 150 })
+ spark.udf.register("timesHundred", (n: Long) => { n * 100 })
val df = Seq(("red", 1), ("red", 2), ("blue", 10),
("green", 100), ("green", 200)).toDF("g", "v")
@@ -173,7 +173,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("struct UDF") {
- sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+ spark.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result =
sql("SELECT returnStruct('test', 'test2') as ret")
@@ -182,27 +182,27 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("udf that is transformed") {
- sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+ spark.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
test("type coercion for udf inputs") {
- sqlContext.udf.register("intExpected", (x: Int) => x)
+ spark.udf.register("intExpected", (x: Int) => x)
// pass a decimal to intExpected.
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
test("udf in different types") {
- sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) })
- sqlContext.udf.register("decimalDataFunc",
+ spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) })
+ spark.udf.register("decimalDataFunc",
(a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) })
- sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) })
- sqlContext.udf.register("arrayDataFunc",
+ spark.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) })
+ spark.udf.register("arrayDataFunc",
(data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) })
- sqlContext.udf.register("mapDataFunc",
+ spark.udf.register("mapDataFunc",
(data: scala.collection.Map[Int, String]) => { data })
- sqlContext.udf.register("complexDataFunc",
+ spark.udf.register("complexDataFunc",
(m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } )
checkAnswer(
@@ -235,7 +235,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") {
- val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) })
+ val myUDF = spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) })
// Without the fix, this will fail because we fail to cast data type of b to string
// because myUDF does not know its input data type. With the fix, this query should not
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index a49aaa8b73..3057e016c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -94,7 +94,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
}
test("UDTs and UDFs") {
- sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector])
+ spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector])
pointsRDD.registerTempTable("points")
checkAnswer(
sql("SELECT testType(features) from points"),
@@ -106,7 +106,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
val path = dir.getCanonicalPath
pointsRDD.write.parquet(path)
checkAnswer(
- sqlContext.read.parquet(path),
+ spark.read.parquet(path),
Seq(
Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
@@ -118,7 +118,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
val path = dir.getCanonicalPath
pointsRDD.repartition(1).write.parquet(path)
checkAnswer(
- sqlContext.read.parquet(path),
+ spark.read.parquet(path),
Seq(
Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
@@ -146,7 +146,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
))
val stringRDD = sparkContext.parallelize(data)
- val jsonRDD = sqlContext.read.schema(schema).json(stringRDD)
+ val jsonRDD = spark.read.schema(schema).json(stringRDD)
checkAnswer(
jsonRDD,
Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
@@ -167,7 +167,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
))
val stringRDD = sparkContext.parallelize(data)
- val jsonDataset = sqlContext.read.schema(schema).json(stringRDD)
+ val jsonDataset = spark.read.schema(schema).json(stringRDD)
.as[(Int, UDT.MyDenseVector)]
checkDataset(
jsonDataset,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 01d485ce2d..70a00a43f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.execution
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite}
import org.apache.spark.sql._
import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.TestSQLContext
class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
@@ -251,7 +250,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
def withSQLContext(
- f: SQLContext => Unit,
+ f: SparkSession => Unit,
targetNumPostShufflePartitions: Int,
minNumPostShufflePartitions: Option[Int]): Unit = {
val sparkConf =
@@ -272,9 +271,11 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
case None =>
sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1")
}
- val sparkContext = new SparkContext(sparkConf)
- val sqlContext = new TestSQLContext(sparkContext)
- try f(sqlContext) finally sparkContext.stop()
+
+ val spark = SparkSession.builder
+ .config(sparkConf)
+ .getOrCreate()
+ try f(spark) finally spark.stop()
}
Seq(Some(3), None).foreach { minNumPostShufflePartitions =>
@@ -284,9 +285,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test(s"determining the number of reducers: aggregate operator$testNameNote") {
- val test = { sqlContext: SQLContext =>
+ val test = { spark: SparkSession =>
val df =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 20 as key", "id as value")
val agg = df.groupBy("key").count
@@ -294,7 +295,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Check the answer first.
checkAnswer(
agg,
- sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect())
+ spark.range(0, 20).selectExpr("id", "50 as cnt").collect())
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
@@ -325,13 +326,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test(s"determining the number of reducers: join operator$testNameNote") {
- val test = { sqlContext: SQLContext =>
+ val test = { spark: SparkSession =>
val df1 =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 500 as key1", "id as value1")
val df2 =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 500 as key2", "id as value2")
@@ -339,10 +340,10 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Check the answer first.
val expectedAnswer =
- sqlContext
+ spark
.range(0, 1000)
.selectExpr("id % 500 as key", "id as value")
- .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+ .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
checkAnswer(
join,
expectedAnswer.collect())
@@ -376,16 +377,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test(s"determining the number of reducers: complex query 1$testNameNote") {
- val test = { sqlContext: SQLContext =>
+ val test = { spark: SparkSession =>
val df1 =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 500 as key1", "id as value1")
.groupBy("key1")
.count
.toDF("key1", "cnt1")
val df2 =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 500 as key2", "id as value2")
.groupBy("key2")
@@ -396,7 +397,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Check the answer first.
val expectedAnswer =
- sqlContext
+ spark
.range(0, 500)
.selectExpr("id", "2 as cnt")
checkAnswer(
@@ -428,16 +429,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test(s"determining the number of reducers: complex query 2$testNameNote") {
- val test = { sqlContext: SQLContext =>
+ val test = { spark: SparkSession =>
val df1 =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 500 as key1", "id as value1")
.groupBy("key1")
.count
.toDF("key1", "cnt1")
val df2 =
- sqlContext
+ spark
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id % 500 as key2", "id as value2")
@@ -448,7 +449,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Check the answer first.
val expectedAnswer =
- sqlContext
+ spark
.range(0, 1000)
.selectExpr("id % 500 as key", "2 as cnt", "id as value")
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index ba16810cee..36cde3233d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -50,7 +50,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
}
test("BroadcastExchange same result") {
- val df = sqlContext.range(10)
+ val df = spark.range(10)
val plan = df.queryExecution.executedPlan
val output = plan.output
assert(plan sameResult plan)
@@ -75,7 +75,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
}
test("ShuffleExchange same result") {
- val df = sqlContext.range(10)
+ val df = spark.range(10)
val plan = df.queryExecution.executedPlan
val output = plan.output
assert(plan sameResult plan)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3b2911d056..d2e1ea12fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext {
setupTestData()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
- val planner = sqlContext.sessionState.planner
+ val planner = spark.sessionState.planner
import planner._
val plannedOption = Aggregation(query).headOption
val planned =
@@ -78,7 +78,7 @@ class PlannerSuite extends SharedSQLContext {
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = sparkContext.parallelize(row :: Nil)
- sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
+ spark.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
@@ -136,7 +136,7 @@ class PlannerSuite extends SharedSQLContext {
sql("CACHE TABLE tiny")
val a = testData.as("a")
- val b = sqlContext.table("tiny").as("b")
+ val b = spark.table("tiny").as("b")
val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoinExec => join }
@@ -145,7 +145,7 @@ class PlannerSuite extends SharedSQLContext {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join")
- sqlContext.clearCache()
+ spark.catalog.clearCache()
}
}
}
@@ -154,8 +154,8 @@ class PlannerSuite extends SharedSQLContext {
withTempPath { file =>
val path = file.getCanonicalPath
testData.write.parquet(path)
- val df = sqlContext.read.parquet(path)
- sqlContext.registerDataFrameAsTable(df, "testPushed")
+ val df = spark.read.parquet(path)
+ spark.wrapped.registerDataFrameAsTable(df, "testPushed")
withTempTable("testPushed") {
val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
@@ -295,7 +295,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
@@ -315,7 +315,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
}
@@ -333,7 +333,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
@@ -353,7 +353,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"Exchange should not have been added:\n$outputPlan")
@@ -376,7 +376,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(outputOrdering, outputOrdering)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"No Exchanges should have been added:\n$outputPlan")
@@ -392,7 +392,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq(orderingB)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: SortExec => true }.isEmpty) {
fail(s"Sort should have been added:\n$outputPlan")
@@ -408,7 +408,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq(orderingA)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: SortExec => true }.nonEmpty) {
fail(s"No sorts should have been added:\n$outputPlan")
@@ -425,7 +425,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq(orderingA, orderingB)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: SortExec => true }.isEmpty) {
fail(s"Sort should have been added:\n$outputPlan")
@@ -444,7 +444,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq.empty)),
None)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
@@ -464,7 +464,7 @@ class PlannerSuite extends SharedSQLContext {
requiredChildOrdering = Seq(Seq.empty)),
None)
- val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
@@ -493,7 +493,7 @@ class PlannerSuite extends SharedSQLContext {
shuffle,
shuffle)
- val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan)
+ val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan)
if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) {
fail(s"Should re-use the shuffle:\n$outputPlan")
}
@@ -510,7 +510,7 @@ class PlannerSuite extends SharedSQLContext {
ShuffleExchange(finalPartitioning, inputPlan),
ShuffleExchange(finalPartitioning, inputPlan))
- val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2)
+ val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2)
if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) {
fail(s"Should re-use the two shuffles:\n$outputPlan2")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index c9f517ca34..ad41111bec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.util.Properties
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
class SQLExecutionSuite extends SparkFunSuite {
@@ -50,16 +50,19 @@ class SQLExecutionSuite extends SparkFunSuite {
}
test("concurrent query execution with fork-join pool (SPARK-13747)") {
- val sc = new SparkContext("local[*]", "test")
- val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder
+ .master("local[*]")
+ .appName("test")
+ .getOrCreate()
+
+ import spark.implicits._
try {
// Should not throw IllegalArgumentException
(1 to 100).par.foreach { _ =>
- sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count()
+ spark.sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count()
}
} finally {
- sc.stop()
+ spark.sparkContext.stop()
}
}
@@ -67,8 +70,8 @@ class SQLExecutionSuite extends SparkFunSuite {
* Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently.
*/
private def testConcurrentQueryExecution(sc: SparkContext): Unit = {
- val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder.getOrCreate()
+ import spark.implicits._
// Initialize local properties. This is necessary for the test to pass.
sc.getLocalProperties
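The SQLExecutionSuite change above shows the construction pattern the patch adopts throughout: build a SparkSession directly instead of wiring up a SparkContext plus SQLContext. A minimal standalone sketch (names and the stop call are illustrative; the test above stops the SparkContext directly):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder
  .master("local[*]")
  .appName("example")
  .getOrCreate()
try {
  // toDF now comes from spark.implicits._ rather than sqlContext.implicits._
  import spark.implicits._
  spark.sparkContext.parallelize(1 to 5).map(i => (i, i)).toDF("a", "b").count()
} finally {
  spark.stop()  // also stops the underlying SparkContext
}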
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 073e0b3f00..d7eae21f9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -21,7 +21,7 @@ import scala.language.implicitConversions
import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.test.SQLTestUtils
@@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SQLTestUtils
* class's test helper methods can be used, see [[SortSuite]].
*/
private[sql] abstract class SparkPlanTest extends SparkFunSuite {
- protected def sqlContext: SQLContext
+ protected def spark: SparkSession
/**
* Runs the plan and makes sure the answer matches the expected result.
@@ -90,9 +90,10 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
- case Some(errorMessage) => fail(errorMessage)
- case None =>
+ SparkPlanTest
+ .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
}
}
@@ -114,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
- input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
+ input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -141,13 +142,13 @@ object SparkPlanTest {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
- sqlContext: SQLContext): Option[String] = {
+ spark: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
- executePlan(expectedOutputPlan, sqlContext)
+ executePlan(expectedOutputPlan, spark)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -162,7 +163,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
- executePlan(outputPlan, sqlContext)
+ executePlan(outputPlan, spark)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -202,12 +203,12 @@ object SparkPlanTest {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
- sqlContext: SQLContext): Option[String] = {
+ spark: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
- executePlan(outputPlan, sqlContext)
+ executePlan(outputPlan, spark)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -230,8 +231,8 @@ object SparkPlanTest {
}
}
- private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
- val execution = new QueryExecution(sqlContext.sparkSession, null) {
+ private def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = {
+ val execution = new QueryExecution(spark.sparkSession, null) {
override lazy val sparkPlan: SparkPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 233104ae84..ada60f6919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
test("range/filter should be combined") {
- val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
+ val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
val plan = df.queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
assert(df.collect() === Array(Row(2)))
}
test("Aggregate should be included in WholeStageCodegen") {
- val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+ val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
@@ -44,7 +44,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
}
test("Aggregate with grouping keys should be included in WholeStageCodegen") {
- val df = sqlContext.range(3).groupBy("id").count().orderBy("id")
+ val df = spark.range(3).groupBy("id").count().orderBy("id")
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
@@ -53,10 +53,10 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
}
test("BroadcastHashJoin should be included in WholeStageCodegen") {
- val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
+ val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
val schema = new StructType().add("k", IntegerType).add("v", StringType)
- val smallDF = sqlContext.createDataFrame(rdd, schema)
- val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
+ val smallDF = spark.createDataFrame(rdd, schema)
+ val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
assert(df.queryExecution.executedPlan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
@@ -64,7 +64,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
}
test("Sort should be included in WholeStageCodegen") {
- val df = sqlContext.range(3, 0, -1).toDF().sort(col("id"))
+ val df = spark.range(3, 0, -1).toDF().sort(col("id"))
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
@@ -75,7 +75,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
test("MapElements should be included in WholeStageCodegen") {
import testImplicits._
- val ds = sqlContext.range(10).map(_.toString)
+ val ds = spark.range(10).map(_.toString)
val plan = ds.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
@@ -84,7 +84,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
}
test("typed filter should be included in WholeStageCodegen") {
- val ds = sqlContext.range(10).filter(_ % 2 == 0)
+ val ds = spark.range(10).filter(_ % 2 == 0)
val plan = ds.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
@@ -93,7 +93,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
}
test("back-to-back typed filter should be included in WholeStageCodegen") {
- val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+ val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
val plan = ds.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 50c8745a28..88269a6a2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
@@ -32,7 +33,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
setupTestData()
test("simple columnar query") {
- val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
+ val plan = spark.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -42,14 +43,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// TODO: Improve this test when we have better statistics
sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
.toDF().registerTempTable("sizeTst")
- sqlContext.cacheTable("sizeTst")
+ spark.catalog.cacheTable("sizeTst")
assert(
- sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
- sqlContext.conf.autoBroadcastJoinThreshold)
+ spark.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
+ spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
}
test("projection") {
- val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
+ val plan = spark.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().map {
@@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
- val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
+ val plan = spark.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
- sqlContext.cacheTable("repeatedData")
+ spark.catalog.cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
@@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
- sqlContext.cacheTable("nullableRepeatedData")
+ spark.catalog.cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
@@ -97,7 +98,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT time FROM timestamps"),
timestamps.collect().toSeq)
- sqlContext.cacheTable("timestamps")
+ spark.catalog.cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
@@ -109,7 +110,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
- sqlContext.cacheTable("withEmptyParts")
+ spark.catalog.cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
@@ -178,35 +179,35 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
(i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
- sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
+ spark.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
sql("cache table InMemoryCache_different_data_types")
// Make sure the table is indeed cached.
- sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan
+ spark.table("InMemoryCache_different_data_types").queryExecution.executedPlan
assert(
- sqlContext.isCached("InMemoryCache_different_data_types"),
+ spark.catalog.isCached("InMemoryCache_different_data_types"),
"InMemoryCache_different_data_types should be cached.")
// Issue a query and check the results.
checkAnswer(
sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
- sqlContext.table("InMemoryCache_different_data_types").collect())
- sqlContext.dropTempTable("InMemoryCache_different_data_types")
+ spark.table("InMemoryCache_different_data_types").collect())
+ spark.catalog.dropTempTable("InMemoryCache_different_data_types")
}
test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") {
- val df = sqlContext.range(1, 100).selectExpr("id % 10 as id")
+ val df = spark.range(1, 100).selectExpr("id % 10 as id")
.rdd.map(id => Tuple1(s"str_$id")).toDF("i")
val cached = df.cache()
// count triggers the caching action. It should not throw.
cached.count()
// Make sure the DataFrame is indeed cached.
- assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty)
+ assert(spark.cacheManager.lookupCachedData(cached).nonEmpty)
// Check result.
checkAnswer(
cached,
- sqlContext.range(1, 100).selectExpr("id % 10 as id")
+ spark.range(1, 100).selectExpr("id % 10 as id")
.rdd.map(id => Tuple1(s"str_$id")).toDF("i")
)
@@ -215,7 +216,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") {
- val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s")
+ val data = spark.range(10).selectExpr("id", "cast(id as string) as s")
data.cache()
assert(data.count() === 10)
assert(data.filter($"s" === "3").count() === 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 9164074a3e..48c798986b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -32,23 +32,24 @@ class PartitionBatchPruningSuite
import testImplicits._
- private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize
- private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning
+ private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
+ private lazy val originalInMemoryPartitionPruning =
+ spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
override protected def beforeAll(): Unit = {
super.beforeAll()
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
- sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
+ spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, 10)
// Enable in-memory partition pruning
- sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
+ spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, true)
// Enable in-memory table scan accumulators
- sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
+ spark.conf.set("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
try {
- sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
- sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+ spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, originalColumnBatchSize)
+ spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, originalInMemoryPartitionPruning)
} finally {
super.afterAll()
}
@@ -63,12 +64,12 @@ class PartitionBatchPruningSuite
TestData(key, string)
}, 5).toDF()
pruningData.registerTempTable("pruningData")
- sqlContext.cacheTable("pruningData")
+ spark.catalog.cacheTable("pruningData")
}
override protected def afterEach(): Unit = {
try {
- sqlContext.uncacheTable("pruningData")
+ spark.catalog.uncacheTable("pruningData")
} finally {
super.afterEach()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 3586ddf7b6..5fbab2382a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -37,7 +37,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
override def afterEach(): Unit = {
try {
// drop all databases, tables and functions after each test
- sqlContext.sessionState.catalog.reset()
+ spark.sessionState.catalog.reset()
} finally {
super.afterEach()
}
@@ -66,7 +66,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
private def createDatabase(catalog: SessionCatalog, name: String): Unit = {
catalog.createDatabase(
- CatalogDatabase(name, "", sqlContext.conf.warehousePath, Map()), ignoreIfExists = false)
+ CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()),
+ ignoreIfExists = false)
}
private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = {
@@ -111,7 +112,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("the qualified path of a database is stored in the catalog") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
withTempDir { tmpDir =>
val path = tmpDir.toString
@@ -274,7 +275,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
databaseNames.foreach { dbName =>
val dbNameWithoutBackTicks = cleanIdentifier(dbName)
- assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks))
+ assert(!spark.sessionState.catalog.databaseExists(dbNameWithoutBackTicks))
var message = intercept[AnalysisException] {
sql(s"DROP DATABASE $dbName")
@@ -334,7 +335,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("create table in default db") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent1 = TableIdentifier("tab1", None)
createTable(catalog, tableIdent1)
val expectedTableIdent = tableIdent1.copy(database = Some("default"))
@@ -343,7 +344,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("create table in a specific db") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
createDatabase(catalog, "dbx")
val tableIdent1 = TableIdentifier("tab1", Some("dbx"))
createTable(catalog, tableIdent1)
@@ -352,7 +353,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("alter table: rename") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent1 = TableIdentifier("tab1", Some("dbx"))
val tableIdent2 = TableIdentifier("tab2", Some("dbx"))
createDatabase(catalog, "dbx")
@@ -444,7 +445,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("alter table: set properties") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -471,7 +472,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("alter table: unset properties") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -512,7 +513,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("alter table: bucketing is not supported") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -523,7 +524,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("alter table: skew is not supported") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -560,7 +561,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("alter table: rename partition") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1")
val part2 = Map("b" -> "2")
@@ -661,7 +662,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("drop table - temporary table") {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
sql(
"""
|CREATE TEMPORARY TABLE tab1
@@ -686,7 +687,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
private def testDropTable(isDatasourceTable: Boolean): Unit = {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -705,7 +706,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
// SQLContext does not support create view. Log an error message if tab1 does not exist
sql("DROP VIEW tab1")
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -726,7 +727,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
private def testSetLocation(isDatasourceTable: Boolean): Unit = {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val partSpec = Map("a" -> "1")
createDatabase(catalog, "dbx")
@@ -784,7 +785,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
private def testSetSerde(isDatasourceTable: Boolean): Unit = {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
createTable(catalog, tableIdent)
@@ -830,7 +831,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
private def testAddPartitions(isDatasourceTable: Boolean): Unit = {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1")
val part2 = Map("b" -> "2")
@@ -880,7 +881,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
private def testDropPartitions(isDatasourceTable: Boolean): Unit = {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1")
val part2 = Map("b" -> "2")
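
A minimal standalone sketch of the catalog access these DDLSuite hunks migrate to, assuming Spark 2.0-era APIs; the object and table names below are hypothetical, and because sessionState is package-private the sketch sits under org.apache.spark.sql just as the suites do:

    package org.apache.spark.sql

    import org.apache.spark.sql.catalyst.TableIdentifier

    object CatalogAccessSketch {
      def main(args: Array[String]): Unit = {
        // SparkSession lives in this package, so no extra import is needed for it.
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("catalog-access-sketch")
          .getOrCreate()

        // sessionState is private[sql] in this Spark version, hence the package above.
        val catalog = spark.sessionState.catalog
        assert(catalog.databaseExists("default"))
        assert(!catalog.tableExists(TableIdentifier("no_such_table", Some("default"))))

        spark.stop()
      }
    }
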
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index ac2af77a6e..52dda8c6ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -281,7 +281,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
))
val fakeRDD = new FileScanRDD(
- sqlContext.sparkSession,
+ spark,
(file: PartitionedFile) => Iterator.empty,
Seq(partition)
)
@@ -399,7 +399,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
util.stringToFile(file, "*" * size)
}
- val df = sqlContext.read
+ val df = spark.read
.format(classOf[TestFileFormat].getName)
.load(tempDir.getCanonicalPath)
@@ -409,7 +409,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
l.copy(relation =
r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil))))
}
- Dataset.ofRows(sqlContext.sparkSession, bucketed)
+ Dataset.ofRows(spark, bucketed)
} else {
df
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
index 297731c70c..89d57653ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
@@ -27,7 +27,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
test("sizeInBytes should be the total size of all files") {
withTempDir{ dir =>
dir.delete()
- sqlContext.range(1000).write.parquet(dir.toString)
+ spark.range(1000).write.parquet(dir.toString)
// ignore hidden files
val allFiles = dir.listFiles(new FilenameFilter {
override def accept(dir: File, name: String): Boolean = {
@@ -35,7 +35,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
}
})
val totalSize = allFiles.map(_.length()).sum
- val df = sqlContext.read.parquet(dir.toString)
+ val df = spark.read.parquet(dir.toString)
assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize))
}
}
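
A minimal sketch of the size check above expressed against SparkSession, assuming a local scratch directory; the object name and path are hypothetical:

    import java.io.File

    import org.apache.spark.sql.SparkSession

    object SizeInBytesSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("size-in-bytes-sketch")
          .getOrCreate()

        // Hypothetical scratch directory; any empty, writable local path works.
        val dir = new File(System.getProperty("java.io.tmpdir"), s"size-sketch-${System.nanoTime()}")
        spark.range(1000).write.parquet(dir.toString)

        // Sum the on-disk size of the data files, skipping hidden and metadata files.
        val totalSize = dir.listFiles()
          .filter(f => !f.getName.startsWith(".") && !f.getName.startsWith("_"))
          .map(_.length())
          .sum

        val df = spark.read.parquet(dir.toString)
        assert(df.queryExecution.logical.statistics.sizeInBytes == BigInt(totalSize))

        spark.stop()
      }
    }
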
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 28e59055fa..b6cdc8cfab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -91,7 +91,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("simple csv test") {
- val cars = sqlContext
+ val cars = spark
.read
.format("csv")
.option("header", "false")
@@ -101,7 +101,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("simple csv test with calling another function to load") {
- val cars = sqlContext
+ val cars = spark
.read
.option("header", "false")
.csv(testFile(carsFile))
@@ -110,7 +110,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("simple csv test with type inference") {
- val cars = sqlContext
+ val cars = spark
.read
.format("csv")
.option("header", "true")
@@ -121,7 +121,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test inferring booleans") {
- val result = sqlContext.read
+ val result = spark.read
.format("csv")
.option("header", "true")
.option("inferSchema", "true")
@@ -133,7 +133,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test with alternative delimiter and quote") {
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true"))
.load(testFile(carsAltFile))
@@ -142,7 +142,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("parse unescaped quotes with maxCharsPerColumn") {
- val rows = sqlContext.read
+ val rows = spark.read
.format("csv")
.option("maxCharsPerColumn", "4")
.load(testFile(unescapedQuotesFile))
@@ -154,7 +154,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("bad encoding name") {
val exception = intercept[UnsupportedCharsetException] {
- sqlContext
+ spark
.read
.format("csv")
.option("charset", "1-9588-osi")
@@ -166,7 +166,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("test different encoding") {
// scalastyle:off
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE carsTable USING csv
|OPTIONS (path "${testFile(carsFile8859)}", header "true",
@@ -174,12 +174,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
""".stripMargin.replaceAll("\n", " "))
// scalastyle:on
- verifyCars(sqlContext.table("carsTable"), withHeader = true)
+ verifyCars(spark.table("carsTable"), withHeader = true)
}
test("test aliases sep and encoding for delimiter and charset") {
// scalastyle:off
- val cars = sqlContext
+ val cars = spark
.read
.format("csv")
.option("header", "true")
@@ -192,17 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("DDL test with tab separated file") {
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE carsTable USING csv
|OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t")
""".stripMargin.replaceAll("\n", " "))
- verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false)
+ verifyCars(spark.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false)
}
test("DDL test parsing decimal type") {
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE carsTable
|(yearMade double, makeName string, modelName string, priceTag decimal,
@@ -212,11 +212,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
""".stripMargin.replaceAll("\n", " "))
assert(
- sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1)
+ spark.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1)
}
test("test for DROPMALFORMED parsing mode") {
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.options(Map("header" -> "true", "mode" -> "dropmalformed"))
.load(testFile(carsFile))
@@ -226,7 +226,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("test for FAILFAST parsing mode") {
val exception = intercept[SparkException]{
- sqlContext.read
+ spark.read
.format("csv")
.options(Map("header" -> "true", "mode" -> "failfast"))
.load(testFile(carsFile)).collect()
@@ -236,7 +236,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test for tokens more than the fields in the schema") {
- val cars = sqlContext
+ val cars = spark
.read
.format("csv")
.option("header", "false")
@@ -247,7 +247,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test with null quote character") {
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.option("header", "true")
.option("quote", "")
@@ -258,7 +258,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test with empty file and known schema") {
- val result = sqlContext.read
+ val result = spark.read
.format("csv")
.schema(StructType(List(StructField("column", StringType, false))))
.load(testFile(emptyFile))
@@ -268,25 +268,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("DDL test with empty file") {
- sqlContext.sql(s"""
+ spark.sql(s"""
|CREATE TEMPORARY TABLE carsTable
|(yearMade double, makeName string, modelName string, comments string, grp string)
|USING csv
|OPTIONS (path "${testFile(emptyFile)}", header "false")
""".stripMargin.replaceAll("\n", " "))
- assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0)
+ assert(spark.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0)
}
test("DDL test with schema") {
- sqlContext.sql(s"""
+ spark.sql(s"""
|CREATE TEMPORARY TABLE carsTable
|(yearMade double, makeName string, modelName string, comments string, blank string)
|USING csv
|OPTIONS (path "${testFile(carsFile)}", header "true")
""".stripMargin.replaceAll("\n", " "))
- val cars = sqlContext.table("carsTable")
+ val cars = spark.table("carsTable")
verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false)
assert(
cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank"))
@@ -295,7 +295,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("save csv") {
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.option("header", "true")
.load(testFile(carsFile))
@@ -304,7 +304,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.option("header", "true")
.csv(csvDir)
- val carsCopy = sqlContext.read
+ val carsCopy = spark.read
.format("csv")
.option("header", "true")
.load(csvDir)
@@ -316,7 +316,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("save csv with quote") {
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.option("header", "true")
.load(testFile(carsFile))
@@ -327,7 +327,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.option("quote", "\"")
.save(csvDir)
- val carsCopy = sqlContext.read
+ val carsCopy = spark.read
.format("csv")
.option("header", "true")
.option("quote", "\"")
@@ -338,7 +338,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("commented lines in CSV data") {
- val results = sqlContext.read
+ val results = spark.read
.format("csv")
.options(Map("comment" -> "~", "header" -> "false"))
.load(testFile(commentsFile))
@@ -353,7 +353,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("inferring schema with commented lines in CSV data") {
- val results = sqlContext.read
+ val results = spark.read
.format("csv")
.options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true"))
.load(testFile(commentsFile))
@@ -372,7 +372,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
"header" -> "true",
"inferSchema" -> "true",
"dateFormat" -> "dd/MM/yyyy hh:mm")
- val results = sqlContext.read
+ val results = spark.read
.format("csv")
.options(options)
.load(testFile(datesFile))
@@ -393,7 +393,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
"header" -> "true",
"inferSchema" -> "false",
"dateFormat" -> "dd/MM/yyyy hh:mm")
- val results = sqlContext.read
+ val results = spark.read
.format("csv")
.options(options)
.schema(customSchema)
@@ -416,7 +416,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("setting comment to null disables comment support") {
- val results = sqlContext.read
+ val results = spark.read
.format("csv")
.options(Map("comment" -> "", "header" -> "false"))
.load(testFile(disableCommentsFile))
@@ -439,7 +439,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
StructField("model", StringType, nullable = false),
StructField("comment", StringType, nullable = true),
StructField("blank", StringType, nullable = true)))
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.schema(dataSchema)
.options(Map("header" -> "true", "nullValue" -> "null"))
@@ -454,7 +454,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("save csv with compression codec option") {
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.option("header", "true")
.load(testFile(carsFile))
@@ -468,7 +468,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val compressedFiles = new File(csvDir).listFiles()
assert(compressedFiles.exists(_.getName.endsWith(".csv.gz")))
- val carsCopy = sqlContext.read
+ val carsCopy = spark.read
.format("csv")
.option("header", "true")
.load(csvDir)
@@ -486,7 +486,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
)
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
- val cars = sqlContext.read
+ val cars = spark.read
.format("csv")
.option("header", "true")
.options(extraOptions)
@@ -502,7 +502,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val compressedFiles = new File(csvDir).listFiles()
assert(compressedFiles.exists(!_.getName.endsWith(".csv.gz")))
- val carsCopy = sqlContext.read
+ val carsCopy = spark.read
.format("csv")
.option("header", "true")
.options(extraOptions)
@@ -513,7 +513,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("Schema inference correctly identifies the datatype when data is sparse.") {
- val df = sqlContext.read
+ val df = spark.read
.format("csv")
.option("header", "true")
.option("inferSchema", "true")
@@ -525,7 +525,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("old csv data source name works") {
- val cars = sqlContext
+ val cars = spark
.read
.format("com.databricks.spark.csv")
.option("header", "false")
@@ -535,7 +535,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("nulls, NaNs and Infinity values can be parsed") {
- val numbers = sqlContext
+ val numbers = spark
.read
.format("csv")
.schema(StructType(List(
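
A minimal sketch of the spark.read CSV pattern these CSVSuite hunks converge on, assuming a local file; the path and object name are hypothetical:

    import org.apache.spark.sql.SparkSession

    object CsvReadSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("csv-read-sketch")
          .getOrCreate()

        // The path is hypothetical; any small CSV file with a header row will do.
        val cars = spark.read
          .format("csv")
          .option("header", "true")
          .option("inferSchema", "true")
          .load("/tmp/cars.csv")

        cars.printSchema()
        println(cars.count())

        spark.stop()
      }
    }
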
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
index 1742df31bb..c31dffedbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
@@ -27,16 +27,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
test("allowComments off") {
val str = """{'name': /* hello */ 'Reynold Xin'}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.json(rdd)
assert(df.schema.head.name == "_corrupt_record")
}
test("allowComments on") {
val str = """{'name': /* hello */ 'Reynold Xin'}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowComments", "true").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowComments", "true").json(rdd)
assert(df.schema.head.name == "name")
assert(df.first().getString(0) == "Reynold Xin")
@@ -44,16 +44,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
test("allowSingleQuotes off") {
val str = """{'name': 'Reynold Xin'}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowSingleQuotes", "false").json(rdd)
assert(df.schema.head.name == "_corrupt_record")
}
test("allowSingleQuotes on") {
val str = """{'name': 'Reynold Xin'}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.json(rdd)
assert(df.schema.head.name == "name")
assert(df.first().getString(0) == "Reynold Xin")
@@ -61,16 +61,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
test("allowUnquotedFieldNames off") {
val str = """{name: 'Reynold Xin'}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.json(rdd)
assert(df.schema.head.name == "_corrupt_record")
}
test("allowUnquotedFieldNames on") {
val str = """{name: 'Reynold Xin'}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowUnquotedFieldNames", "true").json(rdd)
assert(df.schema.head.name == "name")
assert(df.first().getString(0) == "Reynold Xin")
@@ -78,16 +78,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
test("allowNumericLeadingZeros off") {
val str = """{"age": 0018}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.json(rdd)
assert(df.schema.head.name == "_corrupt_record")
}
test("allowNumericLeadingZeros on") {
val str = """{"age": 0018}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowNumericLeadingZeros", "true").json(rdd)
assert(df.schema.head.name == "age")
assert(df.first().getLong(0) == 18)
@@ -97,16 +97,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
// JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS.
ignore("allowNonNumericNumbers off") {
val str = """{"age": NaN}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.json(rdd)
assert(df.schema.head.name == "_corrupt_record")
}
ignore("allowNonNumericNumbers on") {
val str = """{"age": NaN}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowNonNumericNumbers", "true").json(rdd)
assert(df.schema.head.name == "age")
assert(df.first().getDouble(0).isNaN)
@@ -114,16 +114,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
test("allowBackslashEscapingAnyCharacter off") {
val str = """{"name": "Cazen Lee", "price": "\$10"}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd)
assert(df.schema.head.name == "_corrupt_record")
}
test("allowBackslashEscapingAnyCharacter on") {
val str = """{"name": "Cazen Lee", "price": "\$10"}"""
- val rdd = sqlContext.sparkContext.parallelize(Seq(str))
- val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd)
+ val rdd = spark.sparkContext.parallelize(Seq(str))
+ val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd)
assert(df.schema.head.name == "name")
assert(df.schema.last.name == "price")
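
A minimal sketch of toggling a JSON parser option through spark.read, mirroring the allowComments cases above; the object name is hypothetical:

    import org.apache.spark.sql.SparkSession

    object JsonOptionsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("json-options-sketch")
          .getOrCreate()

        // A record that only parses once comments are allowed.
        val rdd = spark.sparkContext.parallelize(Seq("""{'name': /* hello */ 'Reynold Xin'}"""))

        val strict  = spark.read.json(rdd)                                   // lands in _corrupt_record
        val relaxed = spark.read.option("allowComments", "true").json(rdd)   // parses the name field

        strict.printSchema()
        relaxed.show()

        spark.stop()
      }
    }
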
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index b1279abd63..63fe4658d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -229,7 +229,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Complex field and type inferring with null in sampling") {
- val jsonDF = sqlContext.read.json(jsonNullStruct)
+ val jsonDF = spark.read.json(jsonNullStruct)
val expectedSchema = StructType(
StructField("headers", StructType(
StructField("Charset", StringType, true) ::
@@ -248,7 +248,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Primitive field and type inferring") {
- val jsonDF = sqlContext.read.json(primitiveFieldAndType)
+ val jsonDF = spark.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType(20, 0), true) ::
@@ -276,7 +276,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Complex field and type inferring") {
- val jsonDF = sqlContext.read.json(complexFieldAndType1)
+ val jsonDF = spark.read.json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
@@ -375,7 +375,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("GetField operation on complex data type") {
- val jsonDF = sqlContext.read.json(complexFieldAndType1)
+ val jsonDF = spark.read.json(complexFieldAndType1)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -391,7 +391,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Type conflict in primitive field values") {
- val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = spark.read.json(primitiveFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
@@ -463,7 +463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
ignore("Type conflict in primitive field values (Ignored)") {
- val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = spark.read.json(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
// Right now, the analyzer does not promote strings in a boolean expression.
@@ -516,7 +516,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Type conflict in complex field values") {
- val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict)
+ val jsonDF = spark.read.json(complexFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("array", ArrayType(LongType, true), true) ::
@@ -540,7 +540,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Type conflict in array elements") {
- val jsonDF = sqlContext.read.json(arrayElementTypeConflict)
+ val jsonDF = spark.read.json(arrayElementTypeConflict)
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
@@ -568,7 +568,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Handling missing fields") {
- val jsonDF = sqlContext.read.json(missingFields)
+ val jsonDF = spark.read.json(missingFields)
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
@@ -588,7 +588,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = sqlContext.read.json(path)
+ val jsonDF = spark.read.json(path)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType(20, 0), true) ::
@@ -620,7 +620,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path)
+ val jsonDF = spark.read.option("primitivesAsString", "true").json(path)
val expectedSchema = StructType(
StructField("bigInteger", StringType, true) ::
@@ -648,7 +648,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Loading a JSON dataset primitivesAsString returns complex fields as strings") {
- val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1)
+ val jsonDF = spark.read.option("primitivesAsString", "true").json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
@@ -746,7 +746,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") {
- val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType)
+ val jsonDF = spark.read.option("prefersDecimal", "true").json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType(20, 0), true) ::
@@ -777,7 +777,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val mixedIntegerAndDoubleRecords = sparkContext.parallelize(
"""{"a": 3, "b": 1.1}""" ::
s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil)
- val jsonDF = sqlContext.read
+ val jsonDF = spark.read
.option("prefersDecimal", "true")
.json(mixedIntegerAndDoubleRecords)
@@ -796,7 +796,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Infer big integers correctly even when it does not fit in decimal") {
- val jsonDF = sqlContext.read
+ val jsonDF = spark.read
.json(bigIntegerRecords)
// The value in `a` field will be a double as it does not fit in decimal. For `b` field,
@@ -810,7 +810,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Infer floating-point values correctly even when it does not fit in decimal") {
- val jsonDF = sqlContext.read
+ val jsonDF = spark.read
.option("prefersDecimal", "true")
.json(floatingValueRecords)
@@ -823,7 +823,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(expectedSchema === jsonDF.schema)
checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01)))
- val mergedJsonDF = sqlContext.read
+ val mergedJsonDF = spark.read
.option("prefersDecimal", "true")
.json(floatingValueRecords ++ bigIntegerRecords)
@@ -881,7 +881,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
- val jsonDF1 = sqlContext.read.schema(schema).json(path)
+ val jsonDF1 = spark.read.schema(schema).json(path)
assert(schema === jsonDF1.schema)
@@ -898,7 +898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
"this is a simple string.")
)
- val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType)
+ val jsonDF2 = spark.read.schema(schema).json(primitiveFieldAndType)
assert(schema === jsonDF2.schema)
@@ -919,7 +919,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("Applying schemas with MapType") {
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
- val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1)
+ val jsonWithSimpleMap = spark.read.schema(schemaWithSimpleMap).json(mapType1)
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
@@ -947,7 +947,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val schemaWithComplexMap = StructType(
StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
- val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2)
+ val jsonWithComplexMap = spark.read.schema(schemaWithComplexMap).json(mapType2)
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
@@ -973,7 +973,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-2096 Correctly parse dot notations") {
- val jsonDF = sqlContext.read.json(complexFieldAndType2)
+ val jsonDF = spark.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -991,7 +991,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-3390 Complex arrays") {
- val jsonDF = sqlContext.read.json(complexFieldAndType2)
+ val jsonDF = spark.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -1014,7 +1014,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-3308 Read top level JSON arrays") {
- val jsonDF = sqlContext.read.json(jsonArray)
+ val jsonDF = spark.read.json(jsonArray)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -1035,7 +1035,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
StructField("a", StringType, true) :: Nil)
// `FAILFAST` mode should throw an exception for corrupt records.
val exceptionOne = intercept[SparkException] {
- sqlContext.read
+ spark.read
.option("mode", "FAILFAST")
.json(corruptRecords)
.collect()
@@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {"))
val exceptionTwo = intercept[SparkException] {
- sqlContext.read
+ spark.read
.option("mode", "FAILFAST")
.schema(schema)
.json(corruptRecords)
@@ -1060,7 +1060,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val schemaTwo = StructType(
StructField("a", StringType, true) :: Nil)
// `DROPMALFORMED` mode should skip corrupt records
- val jsonDFOne = sqlContext.read
+ val jsonDFOne = spark.read
.option("mode", "DROPMALFORMED")
.json(corruptRecords)
checkAnswer(
@@ -1069,7 +1069,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
assert(jsonDFOne.schema === schemaOne)
- val jsonDFTwo = sqlContext.read
+ val jsonDFTwo = spark.read
.option("mode", "DROPMALFORMED")
.schema(schemaTwo)
.json(corruptRecords)
@@ -1083,7 +1083,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// Test if we can query corrupt records.
withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") {
withTempTable("jsonTable") {
- val jsonDF = sqlContext.read.json(corruptRecords)
+ val jsonDF = spark.read.json(corruptRecords)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
StructField("_unparsed", StringType, true) ::
@@ -1134,7 +1134,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-13953 Rename the corrupt record field via option") {
- val jsonDF = sqlContext.read
+ val jsonDF = spark.read
.option("columnNameOfCorruptRecord", "_malformed")
.json(corruptRecords)
val schema = StructType(
@@ -1155,7 +1155,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-4068: nulls in arrays") {
- val jsonDF = sqlContext.read.json(nullsInArrays)
+ val jsonDF = spark.read.json(nullsInArrays)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
@@ -1201,7 +1201,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
+ val df1 = spark.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDF
val result = df2.toJSON.collect()
@@ -1224,7 +1224,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = sqlContext.createDataFrame(rowRDD2, schema2)
+ val df3 = spark.createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDF
val result2 = df4.toJSON.collect()
@@ -1232,8 +1232,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
- val jsonDF = sqlContext.read.json(primitiveFieldAndType)
- val primTable = sqlContext.read.json(jsonDF.toJSON.rdd)
+ val jsonDF = spark.read.json(primitiveFieldAndType)
+ val primTable = spark.read.json(jsonDF.toJSON.rdd)
primTable.registerTempTable("primitiveTable")
checkAnswer(
sql("select * from primitiveTable"),
@@ -1245,8 +1245,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
"this is a simple string.")
)
- val complexJsonDF = sqlContext.read.json(complexFieldAndType1)
- val compTable = sqlContext.read.json(complexJsonDF.toJSON.rdd)
+ val complexJsonDF = spark.read.json(complexFieldAndType1)
+ val compTable = spark.read.json(complexJsonDF.toJSON.rdd)
compTable.registerTempTable("complexTable")
// Access elements of a primitive array.
checkAnswer(
@@ -1316,7 +1316,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val d1 = DataSource(
- sqlContext.sparkSession,
+ spark,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
bucketSpec = None,
@@ -1324,7 +1324,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
options = Map("path" -> path)).resolveRelation()
val d2 = DataSource(
- sqlContext.sparkSession,
+ spark,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
bucketSpec = None,
@@ -1345,16 +1345,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
withTempDir { dir =>
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
- val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1)
+ val df = spark.read.schema(schemaWithSimpleMap).json(mapType1)
val path = dir.getAbsolutePath
df.write.mode("overwrite").parquet(path)
// order of MapType is not defined
- assert(sqlContext.read.parquet(path).count() == 5)
+ assert(spark.read.parquet(path).count() == 5)
- val df2 = sqlContext.read.json(corruptRecords)
+ val df2 = spark.read.json(corruptRecords)
df2.write.mode("overwrite").parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df2.collect())
+ checkAnswer(spark.read.parquet(path), df2.collect())
}
}
}
@@ -1387,7 +1387,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
"col1",
"abd")
- sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
+ spark.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
checkAnswer(sql(
@@ -1447,7 +1447,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
Row(4.75.toFloat, Seq(false, true)),
new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))
val data =
- Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil
+ Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil
// Data generated by previous versions.
// scalastyle:off
@@ -1462,7 +1462,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// scalastyle:on
// Generate data for the current version.
- val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
withTempPath { path =>
df.write.format("json").mode("overwrite").save(path.getCanonicalPath)
@@ -1486,13 +1486,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
"Spark 1.4.1",
"Spark 1.5.0",
"Spark 1.5.0",
- "Spark " + sqlContext.sparkContext.version,
- "Spark " + sqlContext.sparkContext.version)
+ "Spark " + spark.sparkContext.version,
+ "Spark " + spark.sparkContext.version)
val expectedResult = col0Values.map { v =>
Row.fromSeq(Seq(v) ++ constantValues)
}
checkAnswer(
- sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath),
+ spark.read.format("json").schema(schema).load(path.getCanonicalPath),
expectedResult
)
}
@@ -1502,16 +1502,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(2)
+ val df = spark.range(2)
df.write.json(path + "/p=1")
df.write.json(path + "/p=2")
- assert(sqlContext.read.json(path).count() === 4)
+ assert(spark.read.json(path).count() === 4)
val extraOptions = Map(
"mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName,
"mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName
)
- assert(sqlContext.read.options(extraOptions).json(path).count() === 2)
+ assert(spark.read.options(extraOptions).json(path).count() === 2)
}
}
@@ -1525,12 +1525,12 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
{
// We need to make sure we can infer the schema.
- val jsonDF = sqlContext.read.json(additionalCorruptRecords)
+ val jsonDF = spark.read.json(additionalCorruptRecords)
assert(jsonDF.schema === schema)
}
{
- val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords)
+ val jsonDF = spark.read.schema(schema).json(additionalCorruptRecords)
jsonDF.registerTempTable("jsonTable")
// In HiveContext, backticks should be used to access columns starting with an underscore.
@@ -1563,7 +1563,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
StructField("a", StructType(
StructField("b", StringType) :: Nil
)) :: Nil)
- val jsonDF = sqlContext.read.schema(schema).json(path)
+ val jsonDF = spark.read.schema(schema).json(path)
assert(jsonDF.count() == 2)
}
}
@@ -1575,7 +1575,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = sqlContext.read.json(path)
+ val jsonDF = spark.read.json(path)
val jsonDir = new File(dir, "json").getCanonicalPath
jsonDF.coalesce(1).write
.format("json")
@@ -1585,7 +1585,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val compressedFiles = new File(jsonDir).listFiles()
assert(compressedFiles.exists(_.getName.endsWith(".json.gz")))
- val jsonCopy = sqlContext.read
+ val jsonCopy = spark.read
.format("json")
.load(jsonDir)
@@ -1611,7 +1611,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = sqlContext.read.json(path)
+ val jsonDF = spark.read.json(path)
val jsonDir = new File(dir, "json").getCanonicalPath
jsonDF.coalesce(1).write
.format("json")
@@ -1622,7 +1622,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val compressedFiles = new File(jsonDir).listFiles()
assert(compressedFiles.exists(!_.getName.endsWith(".json.gz")))
- val jsonCopy = sqlContext.read
+ val jsonCopy = spark.read
.format("json")
.options(extraOptions)
.load(jsonDir)
@@ -1637,7 +1637,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("Casting long as timestamp") {
withTempTable("jsonTable") {
val schema = (new StructType).add("ts", TimestampType)
- val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong)
+ val jsonDF = spark.read.schema(schema).json(timestampAsLong)
jsonDF.registerTempTable("jsonTable")
@@ -1657,8 +1657,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val json = s"""
|{"a": [{$nested}], "b": [{$nested}]}
""".stripMargin
- val rdd = sqlContext.sparkContext.makeRDD(Seq(json))
- val df = sqlContext.read.json(rdd)
+ val rdd = spark.sparkContext.makeRDD(Seq(json))
+ val df = spark.read.json(rdd)
assert(df.schema.size === 2)
df.collect()
}
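
A minimal sketch of reading JSON with a user-supplied schema via spark.read, as several JsonSuite hunks above do; the object name and sample records are hypothetical:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    object JsonSchemaSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("json-schema-sketch")
          .getOrCreate()

        val records = spark.sparkContext.parallelize(Seq(
          """{"a": "one"}""",
          """{"a": "two", "b": "dropped by the schema"}"""))

        // Supplying a schema skips inference and keeps only the declared columns.
        val schema = StructType(StructField("a", StringType, nullable = true) :: Nil)
        val df = spark.read.schema(schema).json(records)

        assert(df.schema == schema)
        assert(df.count() == 2)

        spark.stop()
      }
    }
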
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 2873c6a881..f4a3336643 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -18,13 +18,13 @@
package org.apache.spark.sql.execution.datasources.json
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
private[json] trait TestJsonData {
- protected def sqlContext: SQLContext
+ protected def spark: SparkSession
def primitiveFieldAndType: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@@ -35,7 +35,7 @@ private[json] trait TestJsonData {
}""" :: Nil)
def primitiveFieldValueTypeConflict: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -46,14 +46,14 @@ private[json] trait TestJsonData {
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
def jsonNullStruct: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
def complexFieldValueTypeConflict: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@@ -64,14 +64,14 @@ private[json] trait TestJsonData {
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
def arrayElementTypeConflict: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
def missingFields: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
@@ -79,7 +79,7 @@ private[json] trait TestJsonData {
"""{"e":"str"}""" :: Nil)
def complexFieldAndType1: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@@ -95,7 +95,7 @@ private[json] trait TestJsonData {
}""" :: Nil)
def complexFieldAndType2: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@@ -149,7 +149,7 @@ private[json] trait TestJsonData {
}""" :: Nil)
def mapType1: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
@@ -157,7 +157,7 @@ private[json] trait TestJsonData {
"""{"map": {"e": null}}""" :: Nil)
def mapType2: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -166,21 +166,21 @@ private[json] trait TestJsonData {
"""{"map": {"f": {"field1": null}}}""" :: Nil)
def nullsInArrays: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
def jsonArray: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
def corruptRecords: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@@ -189,7 +189,7 @@ private[json] trait TestJsonData {
"""]""" :: Nil)
def additionalCorruptRecords: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"dummy":"test"}""" ::
"""[1,2,3]""" ::
"""":"test", "a":1}""" ::
@@ -197,7 +197,7 @@ private[json] trait TestJsonData {
""" ","ian":"test"}""" :: Nil)
def emptyRecords: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
@@ -206,23 +206,23 @@ private[json] trait TestJsonData {
"""]""" :: Nil)
def timestampAsLong: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"ts":1451732645}""" :: Nil)
def arrayAndStructRecords: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"""{"a": {"b": 1}}""" ::
"""{"a": []}""" :: Nil)
def floatingValueRecords: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil)
def bigIntegerRecords: RDD[String] =
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil)
- lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
+ lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil)
- def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]())
+ def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]())
}
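
A minimal sketch of a fixture trait shaped like TestJsonData after this change, i.e. parameterized on a SparkSession rather than a SQLContext; the trait and object names are hypothetical:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SparkSession

    // The trait only declares the SparkSession it needs; the mixing suite
    // (or demo object) supplies it.
    trait JsonFixtures {
      protected def spark: SparkSession

      def singleRecord: RDD[String] =
        spark.sparkContext.parallelize("""{"a": 123}""" :: Nil)

      def emptyRecords: RDD[String] =
        spark.sparkContext.parallelize(Seq.empty[String])
    }

    object JsonFixturesDemo extends JsonFixtures {
      override protected lazy val spark: SparkSession =
        SparkSession.builder()
          .master("local[1]")
          .appName("json-fixtures-demo")
          .getOrCreate()

      def main(args: Array[String]): Unit = {
        println(spark.read.json(singleRecord).schema.treeString)
        spark.stop()
      }
    }
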
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
index f98ea8c5ae..6509e04e85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
@@ -67,7 +67,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
logParquetSchema(path)
- checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(path), (0 until 10).map { i =>
Row(
i % 2 == 0,
i,
@@ -114,7 +114,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
logParquetSchema(path)
- checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(path), (0 until 10).map { i =>
if (i % 3 == 0) {
Row.apply(Seq.fill(7)(null): _*)
} else {
@@ -155,7 +155,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
logParquetSchema(path)
- checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(path), (0 until 10).map { i =>
Row(
Seq.tabulate(3)(i => s"val_$i"),
if (i % 3 == 0) null else Seq.tabulate(3)(identity))
@@ -182,7 +182,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
logParquetSchema(path)
- checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(path), (0 until 10).map { i =>
Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j))
})
}
@@ -205,7 +205,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
logParquetSchema(path)
- checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(path), (0 until 10).map { i =>
Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap)
})
}
@@ -221,7 +221,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
logParquetSchema(path)
- checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(path), (0 until 10).map { i =>
Row(
Seq.tabulate(3)(n => s"arr_${i + n}"),
Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap,
@@ -267,7 +267,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared
}
}
- checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES"))
+ checkAnswer(spark.read.parquet(path).filter('suit === "SPADES"), Row("SPADES"))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index 45cc6810d4..57cd70e191 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -38,7 +38,7 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
}
protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
- val hadoopConf = sqlContext.sessionState.newHadoopConf()
+ val hadoopConf = spark.sessionState.newHadoopConf()
val fsPath = new Path(path)
val fs = fsPath.getFileSystem(hadoopConf)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 65635e3c06..45fd6a5d80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -304,7 +304,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// If the "part = 1" filter gets pushed down, this query will throw an exception since
// "part" is not a valid column in the actual Parquet file
checkAnswer(
- sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"),
+ spark.read.parquet(dir.getCanonicalPath).filter("part = 1"),
(1 to 3).map(i => Row(i, i.toString, 1)))
}
}
@@ -321,7 +321,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// If the "part = 1" filter gets pushed down, this query will throw an exception since
// "part" is not a valid column in the actual Parquet file
checkAnswer(
- sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"),
+ spark.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"),
(2 to 3).map(i => Row(i, i.toString, 1)))
}
}
@@ -339,7 +339,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty,
// this query will throw an exception since the project from combinedFilter expects
// two projections while the
- val df1 = sqlContext.read.parquet(dir.getCanonicalPath)
+ val df1 = spark.read.parquet(dir.getCanonicalPath)
assert(df1.filter("a > 1 or b < 2").count() == 2)
}
@@ -358,7 +358,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// test the generate new projection case
// when projects != partitionAndNormalColumnProjs
- val df1 = sqlContext.read.parquet(dir.getCanonicalPath)
+ val df1 = spark.read.parquet(dir.getCanonicalPath)
checkAnswer(
df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"),
@@ -381,7 +381,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// If the "c = 1" filter gets pushed down, this query will throw an exception which
// Parquet emits. This is a Parquet issue (PARQUET-389).
- val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a")
+ val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a")
checkAnswer(
df,
Row(1, "1", null))
@@ -394,7 +394,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
df.write.parquet(pathThree)
// We will remove the temporary metadata when writing Parquet file.
- val schema = sqlContext.read.parquet(pathThree).schema
+ val schema = spark.read.parquet(pathThree).schema
assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
val pathFour = s"${dir.getCanonicalPath}/table4"
@@ -407,7 +407,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// If the "s.c = 1" filter gets pushed down, this query will throw an exception which
// Parquet emits.
- val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1")
+ val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1")
.selectExpr("s")
checkAnswer(dfStruct3, Row(Row(null, 1)))
@@ -420,7 +420,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
dfStruct3.write.parquet(pathSix)
// We will remove the temporary metadata when writing Parquet file.
- val forPathSix = sqlContext.read.parquet(pathSix).schema
+ val forPathSix = spark.read.parquet(pathSix).schema
assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
// sanity test: make sure optional metadata field is not wrongly set.
@@ -429,7 +429,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
val pathEight = s"${dir.getCanonicalPath}/table8"
(4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight)
- val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b")
+ val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b")
checkAnswer(
df2,
Row(1, "1"))
@@ -449,7 +449,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
withTempPath { dir =>
val path = s"${dir.getCanonicalPath}/part=1"
(1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
- val df = sqlContext.read.parquet(path).filter("a = 2")
+ val df = spark.read.parquet(path).filter("a = 2")
// The result should be a single row.
// When a filter is pushed to Parquet, Parquet can apply it to every row.
@@ -470,11 +470,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
(1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path)
checkAnswer(
- sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"),
+ spark.read.parquet(path).where("not (a = 2) or not(b in ('1'))"),
(1 to 5).map(i => Row(i, (i % 2).toString)))
checkAnswer(
- sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"),
+ spark.read.parquet(path).where("not (a = 2 and b in ('1'))"),
(1 to 5).map(i => Row(i, (i % 2).toString)))
}
}
@@ -527,19 +527,19 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// When a filter is pushed to Parquet, Parquet can apply it to every row.
// So, we can check the number of rows returned from the Parquet
// to make sure our filter pushdown work.
- val df = sqlContext.read.parquet(path).where("b in (0,2)")
+ val df = spark.read.parquet(path).where("b in (0,2)")
assert(stripSparkFilter(df).count == 3)
- val df1 = sqlContext.read.parquet(path).where("not (b in (1))")
+ val df1 = spark.read.parquet(path).where("not (b in (1))")
assert(stripSparkFilter(df1).count == 3)
- val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)")
+ val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)")
assert(stripSparkFilter(df2).count == 2)
- val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)")
+ val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)")
assert(stripSparkFilter(df3).count == 4)
- val df4 = sqlContext.read.parquet(path).where("not (a <= 2)")
+ val df4 = spark.read.parquet(path).where("not (a <= 2)")
assert(stripSparkFilter(df4).count == 3)
}
}
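
The ParquetFilterSuite hunks above are a mechanical substitution: the suite keeps its logic and only swaps the sqlContext handle for the spark session provided by SharedSQLContext, so every read goes through spark.read. A minimal standalone sketch of that read/filter pattern, with an illustrative path and columns that are not taken from the patch:

import org.apache.spark.sql.SparkSession

object ParquetReadSketch {
  def main(args: Array[String]): Unit = {
    // Local session for the sketch; the test suites get `spark` from SharedSQLContext instead.
    val spark = SparkSession.builder().master("local[1]").appName("parquet-read-sketch").getOrCreate()
    import spark.implicits._

    val path = "/tmp/parquet-read-sketch"                         // illustrative path
    (1 to 3).map(i => (i, i.toString)).toDF("a", "b")
      .write.mode("overwrite").parquet(path)

    // Formerly sqlContext.read.parquet(...); the reader now hangs off the session.
    val df = spark.read.parquet(path).filter("a > 1")
    println(df.count())                                           // 2

    spark.stop()
  }
}
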
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 32fe5ba127..d0107aae5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -113,7 +113,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { location =>
val path = new Path(location.getCanonicalPath)
- val conf = sqlContext.sessionState.newHadoopConf()
+ val conf = spark.sessionState.newHadoopConf()
writeMetadata(parquetSchema, path, conf)
readParquetFile(path.toString)(df => {
val sparkTypes = df.schema.map(_.dataType)
@@ -132,7 +132,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
testStandardAndLegacyModes("fixed-length decimals") {
def makeDecimalRDD(decimal: DecimalType): DataFrame = {
- sqlContext
+ spark
.range(1000)
// Parquet doesn't allow column names with spaces, have to add an alias here.
// Minus 500 here so that negative decimals are also tested.
@@ -250,10 +250,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { location =>
val path = new Path(location.getCanonicalPath)
- val conf = sqlContext.sessionState.newHadoopConf()
+ val conf = spark.sessionState.newHadoopConf()
writeMetadata(parquetSchema, path, conf)
val errorMessage = intercept[Throwable] {
- sqlContext.read.parquet(path.toString).printSchema()
+ spark.read.parquet(path.toString).printSchema()
}.toString
assert(errorMessage.contains("Parquet type not supported"))
}
@@ -271,15 +271,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { location =>
val path = new Path(location.getCanonicalPath)
- val conf = sqlContext.sessionState.newHadoopConf()
+ val conf = spark.sessionState.newHadoopConf()
writeMetadata(parquetSchema, path, conf)
- val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType)
+ val sparkTypes = spark.read.parquet(path.toString).schema.map(_.dataType)
assert(sparkTypes === expectedSparkTypes)
}
}
test("compression codec") {
- val hadoopConf = sqlContext.sessionState.newHadoopConf()
+ val hadoopConf = spark.sessionState.newHadoopConf()
def compressionCodecFor(path: String, codecName: String): String = {
val codecs = for {
footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf)
@@ -296,7 +296,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
def checkCompressionCodec(codec: CompressionCodecName): Unit = {
withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) {
withParquetFile(data) { path =>
- assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) {
+ assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) {
compressionCodecFor(path, codec.name())
}
}
@@ -304,7 +304,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
// Checks default compression codec
- checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec))
+ checkCompressionCodec(
+ CompressionCodecName.fromConf(spark.conf.get(SQLConf.PARQUET_COMPRESSION)))
checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
checkCompressionCodec(CompressionCodecName.GZIP)
@@ -351,7 +352,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
test("write metadata") {
- val hadoopConf = sqlContext.sessionState.newHadoopConf()
+ val hadoopConf = spark.sessionState.newHadoopConf()
withTempPath { file =>
val path = new Path(file.toURI.toString)
val fs = FileSystem.getLocal(hadoopConf)
@@ -433,7 +434,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { location =>
val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString)
val path = new Path(location.getCanonicalPath)
- val conf = sqlContext.sessionState.newHadoopConf()
+ val conf = spark.sessionState.newHadoopConf()
writeMetadata(parquetSchema, path, conf, extraMetadata)
readParquetFile(path.toString) { df =>
@@ -455,7 +456,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
)
withTempPath { dir =>
val message = intercept[SparkException] {
- sqlContext.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath)
+ spark.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath)
}.getCause.getMessage
assert(message === "Intentional exception for testing purposes")
}
@@ -465,10 +466,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
// In 1.3.0, saving to a fs other than file: without configuring core-site.xml would get:
// IllegalArgumentException: Wrong FS: hdfs://..., expected: file:///
intercept[Throwable] {
- sqlContext.read.parquet("file:///nonexistent")
+ spark.read.parquet("file:///nonexistent")
}
val errorMessage = intercept[Throwable] {
- sqlContext.read.parquet("hdfs://nonexistent")
+ spark.read.parquet("hdfs://nonexistent")
}.toString
assert(errorMessage.contains("UnknownHostException"))
}
@@ -486,14 +487,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
withTempPath { dir =>
val m1 = intercept[SparkException] {
- sqlContext.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath)
+ spark.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath)
}.getCause.getMessage
assert(m1.contains("Intentional exception for testing purposes"))
}
withTempPath { dir =>
val m2 = intercept[SparkException] {
- val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1)
+ val df = spark.range(1).select('id as 'a, 'id as 'b).coalesce(1)
df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath)
}.getCause.getMessage
assert(m2.contains("Intentional exception for testing purposes"))
@@ -512,11 +513,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
ParquetOutputFormat.ENABLE_DICTIONARY -> "true"
)
- val hadoopConf = sqlContext.sessionState.newHadoopConfWithOptions(extraOptions)
+ val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions)
withTempPath { dir =>
val path = s"${dir.getCanonicalPath}/part-r-0.parquet"
- sqlContext.range(1 << 16).selectExpr("(id % 4) AS i")
+ spark.range(1 << 16).selectExpr("(id % 4) AS i")
.coalesce(1).write.options(extraOptions).mode("overwrite").parquet(path)
val blockMetadata = readFooter(new Path(path), hadoopConf).getBlocks.asScala.head
@@ -531,7 +532,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("null and non-null strings") {
// Create a dataset where the first values are NULL and then some non-null values. The
// number of non-nulls needs to be bigger than the ParquetReader batch size.
- val data: Dataset[String] = sqlContext.range(200).map (i =>
+ val data: Dataset[String] = spark.range(200).map (i =>
if (i < 150) null
else "a"
)
@@ -554,7 +555,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary
readResourceParquetFile("dec-in-i32.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
+ spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
}
}
}
@@ -565,7 +566,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary
readResourceParquetFile("dec-in-i64.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
+ spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
}
}
}
@@ -576,7 +577,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
checkAnswer(
// Decimal column in this file is encoded using plain dictionary
readResourceParquetFile("dec-in-fixed-len.parquet"),
- sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
+ spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
}
}
}
@@ -589,7 +590,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
var hash2: Int = 0
(false :: true :: Nil).foreach { v =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> v.toString) {
- val df = sqlContext.read.parquet(dir.getCanonicalPath)
+ val df = spark.read.parquet(dir.getCanonicalPath)
val rows = df.queryExecution.toRdd.map(_.copy()).collect()
val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow])
if (!v) {
@@ -607,7 +608,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("VectorizedParquetRecordReader - direct path read") {
val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString))
withTempPath { dir =>
- sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath)
+ spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath)
val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0);
{
val reader = new VectorizedParquetRecordReader
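
Two substitutions recur in ParquetIOSuite: Hadoop configurations now come from spark.sessionState.newHadoopConf() (a session-internal handle the suites can use because they live under org.apache.spark.sql), and SQL settings are read through the session's runtime config instead of the SQLConf field. A sketch of the public config path, assuming the Spark 2.0 key name spelled out below:

import org.apache.spark.sql.SparkSession

object ParquetConfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("parquet-conf-sketch").getOrCreate()

    // SQLConf.PARQUET_COMPRESSION.key resolves to this string.
    val key = "spark.sql.parquet.compression.codec"

    // spark.conf replaces sqlContext.conf.setConfString / sqlContext.conf.parquetCompressionCodec.
    spark.conf.set(key, "gzip")
    println(spark.conf.get(key).toUpperCase)                      // GZIP

    spark.stop()
  }
}
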
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala
index 83b65fb419..9dc56292c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala
@@ -81,7 +81,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS
logParquetSchema(protobufStylePath)
checkAnswer(
- sqlContext.read.parquet(dir.getCanonicalPath),
+ spark.read.parquet(dir.getCanonicalPath),
Seq(
Row(Seq(0, 1)),
Row(Seq(2, 3))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index b4d35be05d..8707e13461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -400,7 +400,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
// Introduce a _temporary dir to the base dir to test the robustness of the schema discovery process.
new File(base.getCanonicalPath, "_temporary").mkdir()
- sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
+ spark.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -484,7 +484,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
+ spark.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -532,7 +532,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
+ val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
@@ -572,7 +572,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
+ val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
@@ -604,7 +604,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
(1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"),
makePartitionDir(base, defaultPartitionName, "pi" -> 2))
- sqlContext
+ spark
.read
.option("mergeSchema", "true")
.format("parquet")
@@ -622,7 +622,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
test("SPARK-7749 Non-partitioned table should have empty partition spec") {
withTempPath { dir =>
(1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath)
- val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution
+ val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution
queryExecution.analyzed.collectFirst {
case LogicalRelation(relation: HadoopFsRelation, _, _) =>
assert(relation.partitionSpec === PartitionSpec.emptySpec)
@@ -636,7 +636,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
withTempPath { dir =>
val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s")
df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath)
- checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect())
+ checkAnswer(spark.read.parquet(dir.getCanonicalPath), df.collect())
}
}
@@ -676,12 +676,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
}
val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
- val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+ val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
withTempPath { dir =>
df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
val fields = schema.map(f => Column(f.name).cast(f.dataType))
- checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row)
+ checkAnswer(spark.read.load(dir.toString).select(fields: _*), row)
}
}
@@ -697,7 +697,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store"))
Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
- checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df)
+ checkAnswer(spark.read.format("parquet").load(dir.getCanonicalPath), df)
}
}
@@ -714,7 +714,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS"))
Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
- checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df)
+ checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df)
}
withTempPath { dir =>
@@ -731,7 +731,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS"))
Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
- checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df)
+ checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df)
}
}
@@ -746,7 +746,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
.save(tablePath.getCanonicalPath)
val twoPartitionsDF =
- sqlContext
+ spark
.read
.option("basePath", tablePath.getCanonicalPath)
.parquet(
@@ -756,7 +756,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
checkAnswer(twoPartitionsDF, df.filter("b != 3"))
intercept[AssertionError] {
- sqlContext
+ spark
.read
.parquet(
s"${tablePath.getCanonicalPath}/b=1",
@@ -829,7 +829,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1", "_SUCCESS"))
Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1", "_SUCCESS"))
Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1/d=1", "_SUCCESS"))
- checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df)
+ checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df)
}
}
}
@@ -884,9 +884,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
withTempPath { dir =>
withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") {
val path = dir.getCanonicalPath
- val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1)
+ val df = spark.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1)
df.write.partitionBy("b", "c").parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df)
+ checkAnswer(spark.read.parquet(path), df)
}
}
}
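
The partition-discovery changes are again one-for-one replacements of sqlContext.read with spark.read; the basePath option exercised above is what keeps partition columns visible when only some partition directories are listed. A sketch of that behaviour under assumed paths and column names:

import org.apache.spark.sql.SparkSession

object BasePathSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("base-path-sketch").getOrCreate()

    val table = "/tmp/base-path-sketch"                           // illustrative path
    spark.range(6).selectExpr("id AS a", "CAST(id % 3 AS INT) AS b")
      .write.mode("overwrite").partitionBy("b").parquet(table)

    // With basePath set, partition discovery still finds column `b`
    // even though only two of its directories are listed explicitly.
    val twoParts = spark.read
      .option("basePath", table)
      .parquet(s"$table/b=0", s"$table/b=1")
    twoParts.printSchema()                                        // contains both a and b

    spark.stop()
  }
}
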
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index f1e9726c38..f9f9f80352 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -46,24 +46,24 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
- sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ spark.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
// Query appends, don't test with both read modes.
withParquetTable(data, "t", false) {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
- checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
+ checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple))
}
- sqlContext.sessionState.catalog.dropTable(
+ spark.sessionState.catalog.dropTable(
TableIdentifier("tmp"), ignoreIfNotExists = true)
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
- sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ spark.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
- checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
+ checkAnswer(spark.table("t"), data.map(Row.fromTuple))
}
- sqlContext.sessionState.catalog.dropTable(
+ spark.sessionState.catalog.dropTable(
TableIdentifier("tmp"), ignoreIfNotExists = true)
}
@@ -128,9 +128,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
- val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema)
+ val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
- val df2 = sqlContext.read.parquet(file.getCanonicalPath)
+ val df2 = spark.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
}
}
@@ -139,12 +139,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
- sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
// delete summary files, so if we don't merge part-files, one column will not be included.
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
- assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+ assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@@ -163,9 +163,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
- sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
- assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+ spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@@ -181,19 +181,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
withTempPath { dir =>
val basePath = dir.getCanonicalPath
- sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
+ spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
// Disables the global SQL option for schema merging
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
assertResult(2) {
// Disables schema merging via data source option
- sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
+ spark.read.option("mergeSchema", "false").parquet(basePath).columns.length
}
assertResult(3) {
// Enables schema merging via data source option
- sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
+ spark.read.option("mergeSchema", "true").parquet(basePath).columns.length
}
}
}
@@ -204,10 +204,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
val basePath = dir.getCanonicalPath
val schema = StructType(Array(StructField("name", DecimalType(10, 5), false)))
val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45"))))
- val df = sqlContext.createDataFrame(rowRDD, schema)
+ val df = spark.createDataFrame(rowRDD, schema)
df.write.parquet(basePath)
- val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0)
+ val decimal = spark.read.parquet(basePath).first().getDecimal(0)
assert(Decimal("67123.45") === Decimal(decimal))
}
}
@@ -227,7 +227,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") {
checkAnswer(
- sqlContext.read.option("mergeSchema", "true").parquet(path),
+ spark.read.option("mergeSchema", "true").parquet(path),
Seq(
Row(Row(1, 1, null)),
Row(Row(2, 2, null)),
@@ -240,7 +240,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-10301 requested schema clipping - same schema") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1)
+ val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1)
df.write.parquet(path)
val userDefinedSchema =
@@ -253,7 +253,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(0L, 1L)))
}
}
@@ -261,12 +261,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-11997 parquet with null partition values") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- sqlContext.range(1, 3)
+ spark.range(1, 3)
.selectExpr("if(id % 2 = 0, null, id) AS n", "id")
.write.partitionBy("n").parquet(path)
checkAnswer(
- sqlContext.read.parquet(path).filter("n is null"),
+ spark.read.parquet(path).filter("n is null"),
Row(2, null))
}
}
@@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1)
+ val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1)
df.write.parquet(path)
val userDefinedSchema =
@@ -288,7 +288,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(null, null)))
}
}
@@ -296,7 +296,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-10301 requested schema clipping - requested schema contains physical schema") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1)
+ val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1)
df.write.parquet(path)
val userDefinedSchema =
@@ -311,13 +311,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(0L, 1L, null, null)))
}
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1)
+ val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1)
df.write.parquet(path)
val userDefinedSchema =
@@ -332,7 +332,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(0L, null, null, 3L)))
}
}
@@ -340,7 +340,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-10301 requested schema clipping - physical schema contains requested schema") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext
+ val df = spark
.range(1)
.selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s")
.coalesce(1)
@@ -357,13 +357,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(0L, 1L)))
}
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext
+ val df = spark
.range(1)
.selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s")
.coalesce(1)
@@ -380,7 +380,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(0L, 3L)))
}
}
@@ -388,7 +388,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext
+ val df = spark
.range(1)
.selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s")
.coalesce(1)
@@ -406,7 +406,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(1L, 2L, null)))
}
}
@@ -415,7 +415,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext
+ val df = spark
.range(1)
.selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s")
.coalesce(1)
@@ -436,7 +436,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(Seq(Row(0, null)))))
}
}
@@ -445,12 +445,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df1 = sqlContext
+ val df1 = spark
.range(1)
.selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s")
.coalesce(1)
- val df2 = sqlContext
+ val df2 = spark
.range(1, 2)
.selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s")
.coalesce(1)
@@ -467,7 +467,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Seq(
Row(Row(0, 1, null)),
Row(Row(null, 2, 4))))
@@ -478,12 +478,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df1 = sqlContext
+ val df1 = spark
.range(1)
.selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s")
.coalesce(1)
- val df2 = sqlContext
+ val df2 = spark
.range(1, 2)
.selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s")
.coalesce(1)
@@ -492,7 +492,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
df2.write.mode(SaveMode.Append).parquet(path)
checkAnswer(
- sqlContext
+ spark
.read
.option("mergeSchema", "true")
.parquet(path)
@@ -507,7 +507,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext
+ val df = spark
.range(1)
.selectExpr(
"""NAMED_STRUCT(
@@ -532,7 +532,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
nullable = true)
checkAnswer(
- sqlContext.read.schema(userDefinedSchema).parquet(path),
+ spark.read.schema(userDefinedSchema).parquet(path),
Row(Row(NestedStruct(1, 2L, 3.5D))))
}
}
@@ -585,9 +585,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
+ val df = spark.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
df.write.mode(SaveMode.Overwrite).parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df)
+ checkAnswer(spark.read.parquet(path), df)
}
}
@@ -595,11 +595,11 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withSQLConf("spark.sql.codegen.maxFields" -> "100") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*)
+ val df = spark.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*)
df.write.mode(SaveMode.Overwrite).parquet(path)
// do not return batch, because whole-stage codegen is disabled for wide tables (>200 columns)
- val df2 = sqlContext.read.parquet(path)
+ val df2 = spark.read.parquet(path)
assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty,
"Should not return batch")
checkAnswer(df2, df)
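
ParquetQuerySuite follows the same substitution; the schema-merging tests above hinge on the mergeSchema read option, which can be exercised outside the suite like this (path and column names are illustrative):

import org.apache.spark.sql.SparkSession

object MergeSchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("merge-schema-sketch").getOrCreate()

    val base = "/tmp/merge-schema-sketch"                         // illustrative path
    spark.range(0, 10).toDF("a").write.mode("overwrite").parquet(s"$base/foo=1")
    spark.range(0, 10).toDF("b").write.mode("overwrite").parquet(s"$base/foo=2")

    // Without merging, one file schema plus the partition column is used;
    // with merging, columns from both files show up.
    println(spark.read.option("mergeSchema", "false").parquet(base).columns.length)  // 2
    println(spark.read.option("mergeSchema", "true").parquet(base).columns.length)   // 3

    spark.stop()
  }
}
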
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
index cef541f044..373d3a3a0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
@@ -21,9 +21,9 @@ import java.io.File
import scala.collection.JavaConverters._
import scala.util.Try
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkConf
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.util.{Benchmark, Utils}
/**
@@ -34,12 +34,16 @@ import org.apache.spark.util.{Benchmark, Utils}
object ParquetReadBenchmark {
val conf = new SparkConf()
conf.set("spark.sql.parquet.compression.codec", "snappy")
- val sc = new SparkContext("local[1]", "test-sql-context", conf)
- val sqlContext = new SQLContext(sc)
+
+ val spark = SparkSession.builder
+ .master("local[1]")
+ .appName("test-sql-context")
+ .config(conf)
+ .getOrCreate()
// Set default configs. Individual cases will change them if necessary.
- sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
- sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
+ spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
+ spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
def withTempPath(f: File => Unit): Unit = {
val path = Utils.createTempDir()
@@ -48,17 +52,17 @@ object ParquetReadBenchmark {
}
def withTempTable(tableNames: String*)(f: => Unit): Unit = {
- try f finally tableNames.foreach(sqlContext.dropTempTable)
+ try f finally tableNames.foreach(spark.catalog.dropTempTable)
}
def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
- (keys, values).zipped.foreach(sqlContext.conf.setConfString)
+ val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
+ (keys, values).zipped.foreach(spark.conf.set)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
- case (key, None) => sqlContext.conf.unsetConf(key)
+ case (key, Some(value)) => spark.conf.set(key, value)
+ case (key, None) => spark.conf.unset(key)
}
}
}
@@ -71,18 +75,18 @@ object ParquetReadBenchmark {
withTempPath { dir =>
withTempTable("t1", "tempTable") {
- sqlContext.range(values).registerTempTable("t1")
- sqlContext.sql("select cast(id as INT) as id from t1")
+ spark.range(values).registerTempTable("t1")
+ spark.sql("select cast(id as INT) as id from t1")
.write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+ spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
sqlBenchmark.addCase("SQL Parquet Vectorized") { iter =>
- sqlContext.sql("select sum(id) from tempTable").collect()
+ spark.sql("select sum(id) from tempTable").collect()
}
sqlBenchmark.addCase("SQL Parquet MR") { iter =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- sqlContext.sql("select sum(id) from tempTable").collect()
+ spark.sql("select sum(id) from tempTable").collect()
}
}
@@ -155,20 +159,20 @@ object ParquetReadBenchmark {
def intStringScanBenchmark(values: Int): Unit = {
withTempPath { dir =>
withTempTable("t1", "tempTable") {
- sqlContext.range(values).registerTempTable("t1")
- sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1")
+ spark.range(values).registerTempTable("t1")
+ spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1")
.write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+ spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
val benchmark = new Benchmark("Int and String Scan", values)
benchmark.addCase("SQL Parquet Vectorized") { iter =>
- sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect
+ spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect
}
benchmark.addCase("SQL Parquet MR") { iter =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect
+ spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect
}
}
@@ -189,20 +193,20 @@ object ParquetReadBenchmark {
def stringDictionaryScanBenchmark(values: Int): Unit = {
withTempPath { dir =>
withTempTable("t1", "tempTable") {
- sqlContext.range(values).registerTempTable("t1")
- sqlContext.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1")
+ spark.range(values).registerTempTable("t1")
+ spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1")
.write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+ spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
val benchmark = new Benchmark("String Dictionary", values)
benchmark.addCase("SQL Parquet Vectorized") { iter =>
- sqlContext.sql("select sum(length(c1)) from tempTable").collect
+ spark.sql("select sum(length(c1)) from tempTable").collect
}
benchmark.addCase("SQL Parquet MR") { iter =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
- sqlContext.sql("select sum(length(c1)) from tempTable").collect
+ spark.sql("select sum(length(c1)) from tempTable").collect
}
}
@@ -221,23 +225,23 @@ object ParquetReadBenchmark {
def partitionTableScanBenchmark(values: Int): Unit = {
withTempPath { dir =>
withTempTable("t1", "tempTable") {
- sqlContext.range(values).registerTempTable("t1")
- sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1")
+ spark.range(values).registerTempTable("t1")
+ spark.sql("select id % 2 as p, cast(id as INT) as id from t1")
.write.partitionBy("p").parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+ spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
val benchmark = new Benchmark("Partitioned Table", values)
benchmark.addCase("Read data column") { iter =>
- sqlContext.sql("select sum(id) from tempTable").collect
+ spark.sql("select sum(id) from tempTable").collect
}
benchmark.addCase("Read partition column") { iter =>
- sqlContext.sql("select sum(p) from tempTable").collect
+ spark.sql("select sum(p) from tempTable").collect
}
benchmark.addCase("Read both columns") { iter =>
- sqlContext.sql("select sum(p), sum(id) from tempTable").collect
+ spark.sql("select sum(p), sum(id) from tempTable").collect
}
/*
@@ -256,16 +260,16 @@ object ParquetReadBenchmark {
def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = {
withTempPath { dir =>
withTempTable("t1", "tempTable") {
- sqlContext.range(values).registerTempTable("t1")
- sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " +
+ spark.range(values).registerTempTable("t1")
+ spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " +
s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1")
.write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+ spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
val benchmark = new Benchmark("String with Nulls Scan", values)
benchmark.addCase("SQL Parquet Vectorized") { iter =>
- sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " +
+ spark.sql("select sum(length(c2)) from tempTable where c1 is " +
"not NULL and c2 is not NULL").collect()
}
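
ParquetReadBenchmark builds its own context rather than extending SharedSQLContext, so the change here is structural: one SparkSession.builder call replaces the explicit SparkContext/SQLContext pair, and the save-and-restore config helper moves onto spark.conf. A trimmed, standalone sketch of that helper; the key shown is the vectorized-reader flag the benchmark toggles:

import scala.util.Try
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object BenchmarkSessionSketch {
  val conf = new SparkConf().set("spark.sql.parquet.compression.codec", "snappy")

  val spark = SparkSession.builder
    .master("local[1]")
    .appName("benchmark-session-sketch")
    .config(conf)
    .getOrCreate()

  // Temporarily set configs for `f`, then restore whatever was there before.
  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
    val (keys, values) = pairs.unzip
    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
    keys.zip(values).foreach { case (k, v) => spark.conf.set(k, v) }
    try f finally {
      keys.zip(currentValues).foreach {
        case (key, Some(value)) => spark.conf.set(key, value)
        case (key, None)        => spark.conf.unset(key)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    withSQLConf("spark.sql.parquet.enableVectorizedReader" -> "false") {
      println(spark.conf.get("spark.sql.parquet.enableVectorizedReader"))  // false inside the block
    }
    spark.stop()
  }
}
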
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 90e3d50714..c43b142de2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -453,11 +453,11 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
test("schema merging failure error message") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- sqlContext.range(3).write.parquet(s"$path/p=1")
- sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")
+ spark.range(3).write.parquet(s"$path/p=1")
+ spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")
val message = intercept[SparkException] {
- sqlContext.read.option("mergeSchema", "true").parquet(path).schema
+ spark.read.option("mergeSchema", "true").parquet(path).schema
}.getMessage
assert(message.contains("Failed merging schema of file"))
@@ -466,13 +466,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
// test for second merging (after read Parquet schema in parallel done)
withTempPath { dir =>
val path = dir.getCanonicalPath
- sqlContext.range(3).write.parquet(s"$path/p=1")
- sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")
+ spark.range(3).write.parquet(s"$path/p=1")
+ spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")
- sqlContext.sparkContext.conf.set("spark.default.parallelism", "20")
+ spark.sparkContext.conf.set("spark.default.parallelism", "20")
val message = intercept[SparkException] {
- sqlContext.read.option("mergeSchema", "true").parquet(path).schema
+ spark.read.option("mergeSchema", "true").parquet(path).schema
}.getMessage
assert(message.contains("Failed merging schema:"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index e8c524e9e5..b5fc51603e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -52,7 +52,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(true :: false :: Nil).foreach { vectorized =>
if (!vectorized || testVectorized) {
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
- f(sqlContext.read.parquet(path.toString))
+ f(spark.read.parquet(path.toString))
}
}
}
@@ -66,7 +66,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
- sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
+ spark.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -90,14 +90,14 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
withParquetDataFrame(data, testVectorized) { df =>
- sqlContext.registerDataFrameAsTable(df, tableName)
+ spark.wrapped.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
- sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+ spark.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
@@ -173,6 +173,6 @@ private[sql] trait ParquetTest extends SQLTestUtils {
protected def readResourceParquetFile(name: String): DataFrame = {
val url = Thread.currentThread().getContextClassLoader.getResource(name)
- sqlContext.read.parquet(url.toString)
+ spark.read.parquet(url.toString)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index 88a3d878f9..ff5706999a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -32,7 +32,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
|${readParquetSchema(parquetFilePath.toString)}
""".stripMargin)
- checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
+ checkAnswer(spark.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")
val nonNullablePrimitiveValues = Seq(
@@ -139,7 +139,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
logParquetSchema(path)
checkAnswer(
- sqlContext.read.parquet(path),
+ spark.read.parquet(path),
Seq(
Row(Seq(Seq(0, 1), Seq(2, 3))),
Row(Seq(Seq(4, 5), Seq(6, 7)))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala
index fd56265297..08b7eb3cf7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.execution.datasources.parquet
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkConf
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.util.Benchmark
@@ -36,9 +36,10 @@ object TPCDSBenchmark {
conf.set("spark.driver.memory", "3g")
conf.set("spark.executor.memory", "3g")
conf.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString)
+ conf.setMaster("local[1]")
+ conf.setAppName("test-sql-context")
- val sc = new SparkContext("local[1]", "test-sql-context", conf)
- val sqlContext = new SQLContext(sc)
+ val spark = SparkSession.builder.config(conf).getOrCreate()
// These queries are a subset of the TPCDS benchmark queries and are taken from
// https://github.com/databricks/spark-sql-perf/blob/master/src/main/scala/com/databricks/spark/
@@ -1186,8 +1187,8 @@ object TPCDSBenchmark {
def setupTables(dataLocation: String): Map[String, Long] = {
tables.map { tableName =>
- sqlContext.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName)
- tableName -> sqlContext.table(tableName).count()
+ spark.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName)
+ tableName -> spark.table(tableName).count()
}.toMap
}
@@ -1195,18 +1196,18 @@ object TPCDSBenchmark {
require(dataLocation.nonEmpty,
"please modify the value of dataLocation to point to your local TPCDS data")
val tableSizes = setupTables(dataLocation)
- sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
- sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
+ spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
+ spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
tpcds.filter(q => q._1 != "").foreach {
case (name: String, query: String) =>
- val numRows = sqlContext.sql(query).queryExecution.logical.map {
+ val numRows = spark.sql(query).queryExecution.logical.map {
case ur@UnresolvedRelation(t: TableIdentifier, _) =>
tableSizes.getOrElse(t.table, throw new RuntimeException(s"${t.table} not found."))
case _ => 0L
}.sum
val benchmark = new Benchmark("TPCDS Snappy (scale = 5)", numRows, 5)
benchmark.addCase(name) { i =>
- sqlContext.sql(query).collect()
+ spark.sql(query).collect()
}
benchmark.run()
}
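
TPCDSBenchmark gets the same treatment: the SparkConf now carries the master and app name so a single SparkSession.builder.config(conf).getOrCreate() can stand in for the old two-step construction, and table registration and lookup move to the session. A small sketch of that register-then-query flow with a hypothetical table name:

import org.apache.spark.sql.SparkSession

object TempTableSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("temp-table-sketch").getOrCreate()

    // registerTempTable is still the registration call in this snapshot;
    // lookup and counting move from sqlContext.table to spark.table.
    spark.range(10).toDF("id").registerTempTable("t")             // hypothetical table name
    println(spark.table("t").count())                             // 10

    spark.catalog.dropTempTable("t")
    spark.stop()
  }
}
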
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 923c0b350e..f61fce5d41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -33,20 +33,20 @@ import org.apache.spark.util.Utils
class TextSuite extends QueryTest with SharedSQLContext {
test("reading text file") {
- verifyFrame(sqlContext.read.format("text").load(testFile))
+ verifyFrame(spark.read.format("text").load(testFile))
}
test("SQLContext.read.text() API") {
- verifyFrame(sqlContext.read.text(testFile).toDF())
+ verifyFrame(spark.read.text(testFile).toDF())
}
test("SPARK-12562 verify write.text() can handle column name beyond `value`") {
- val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")
+ val df = spark.read.text(testFile).withColumnRenamed("value", "adwrasdf")
val tempFile = Utils.createTempDir()
tempFile.delete()
df.write.text(tempFile.getCanonicalPath)
- verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath).toDF())
+ verifyFrame(spark.read.text(tempFile.getCanonicalPath).toDF())
Utils.deleteRecursively(tempFile)
}
@@ -55,18 +55,18 @@ class TextSuite extends QueryTest with SharedSQLContext {
val tempFile = Utils.createTempDir()
tempFile.delete()
- val df = sqlContext.range(2)
+ val df = spark.range(2)
intercept[AnalysisException] {
df.write.text(tempFile.getCanonicalPath)
}
intercept[AnalysisException] {
- sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
+ spark.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
}
}
test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
- val testDf = sqlContext.read.text(testFile)
+ val testDf = spark.read.text(testFile)
val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz")
extensionNameMap.foreach {
case (codecName, extension) =>
@@ -75,7 +75,7 @@ class TextSuite extends QueryTest with SharedSQLContext {
testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
val compressedFiles = new File(tempDirPath).listFiles()
assert(compressedFiles.exists(_.getName.endsWith(s".txt$extension")))
- verifyFrame(sqlContext.read.text(tempDirPath).toDF())
+ verifyFrame(spark.read.text(tempDirPath).toDF())
}
val errMsg = intercept[IllegalArgumentException] {
@@ -95,14 +95,14 @@ class TextSuite extends QueryTest with SharedSQLContext {
"mapreduce.map.output.compress.codec" -> classOf[GzipCodec].getName
)
withTempDir { dir =>
- val testDf = sqlContext.read.text(testFile)
+ val testDf = spark.read.text(testFile)
val tempDir = Utils.createTempDir()
val tempDirPath = tempDir.getAbsolutePath
testDf.write.option("compression", "none")
.options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath)
val compressedFiles = new File(tempDirPath).listFiles()
assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz")))
- verifyFrame(sqlContext.read.options(extraOptions).text(tempDirPath).toDF())
+ verifyFrame(spark.read.options(extraOptions).text(tempDirPath).toDF())
}
}
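
TextSuite only swaps the reader handle, but the compression test above also documents how the text source names its output files; a sketch of that write/read round trip under an assumed output directory:

import java.io.File
import org.apache.spark.sql.{SaveMode, SparkSession}

object TextCompressionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("text-sketch").getOrCreate()

    val out = "/tmp/text-compression-sketch"                      // illustrative path
    // The text writer takes a single string column; its name does not have to be `value`.
    val df = spark.range(10).selectExpr("CAST(id AS STRING) AS line")
    df.write.option("compression", "gzip").mode(SaveMode.Overwrite).text(out)

    // Compressed part files carry the codec extension after .txt.
    println(new File(out).listFiles().exists(_.getName.endsWith(".txt.gz")))  // true
    println(spark.read.text(out).count())                         // 10

    spark.stop()
  }
}
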
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 8aa0114d98..4fc52c99fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -33,7 +33,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
}
test("debugCodegen") {
- val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan)
+ val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
assert(res.contains("Subtree 1 / 2"))
assert(res.contains("Subtree 2 / 2"))
assert(res.contains("Object[]"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index b9df43d049..730ec43556 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SQLContext}
+import org.apache.spark.sql.{QueryTest, SparkSession}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
@@ -34,7 +34,7 @@ import org.apache.spark.sql.functions._
* without serializing the hashed relation, which does not happen in local mode.
*/
class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
- protected var sqlContext: SQLContext = null
+ protected var spark: SparkSession = null
/**
* Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled.
@@ -45,26 +45,26 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
.setMaster("local-cluster[2,1,1024]")
.setAppName("testing")
val sc = new SparkContext(conf)
- sqlContext = new SQLContext(sc)
+ spark = SparkSession.builder.getOrCreate()
}
override def afterAll(): Unit = {
- sqlContext.sparkContext.stop()
- sqlContext = null
+ spark.stop()
+ spark = null
}
/**
* Test whether the specified broadcast join updates the peak execution memory accumulator.
*/
private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
- AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) {
- val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
- val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+ AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) {
+ val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+ val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
val df3 = df1.join(broadcast(df2), joinExpression, joinType)
val plan =
- EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan)
+ EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
assert(plan.collect { case p: T => p }.size === 1)
plan.executeCollect()
}
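
The setup change above is the heart of this PR: instead of wrapping an existing SparkContext in a fresh SQLContext, the suite asks the builder for a session, and `getOrCreate()` reuses the SparkContext that was just started. A sketch of that pattern (names here are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

val conf = new SparkConf().setMaster("local[2]").setAppName("builder-sketch")
val sc = new SparkContext(conf)
// The builder picks up the already-active SparkContext instead of creating a new one.
val spark = SparkSession.builder.getOrCreate()
assert(spark.sparkContext eq sc)
// spark.stop() also stops the underlying SparkContext, which is why the teardown
// above can drop the explicit sqlContext.sparkContext.stop() call.
spark.stop()
```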
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 2a4a3690f2..7caeb3be54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -32,7 +32,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.newProductEncoder
import testImplicits.localSeqToDatasetHolder
- private lazy val myUpperCaseData = sqlContext.createDataFrame(
+ private lazy val myUpperCaseData = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, "A"),
Row(2, "B"),
@@ -43,7 +43,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
Row(null, "G")
)), new StructType().add("N", IntegerType).add("L", StringType))
- private lazy val myLowerCaseData = sqlContext.createDataFrame(
+ private lazy val myLowerCaseData = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, "a"),
Row(2, "b"),
@@ -99,7 +99,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
boundCondition,
leftPlan,
rightPlan)
- EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin)
+ EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin)
}
def makeShuffledHashJoin(
@@ -113,7 +113,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
- EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin)
+ EnsureRequirements(spark.sessionState.conf).apply(filteredJoin)
}
def makeSortMergeJoin(
@@ -124,7 +124,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
rightPlan: SparkPlan) = {
val sortMergeJoin =
joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan)
- EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin)
+ EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index c26cb8483e..001feb0f2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
- private lazy val left = sqlContext.createDataFrame(
+ private lazy val left = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(2, 100.0),
@@ -42,7 +42,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
Row(null, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
- private lazy val right = sqlContext.createDataFrame(
+ private lazy val right = spark.createDataFrame(
sparkContext.parallelize(Seq(
Row(0, 0.0),
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
@@ -82,7 +82,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(sqlContext.sessionState.conf).apply(
+ EnsureRequirements(spark.sessionState.conf).apply(
ShuffledHashJoinExec(
leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
@@ -115,7 +115,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(sqlContext.sessionState.conf).apply(
+ EnsureRequirements(spark.sessionState.conf).apply(
SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index d41e88a0aa..1b82769428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,21 +71,21 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
df: DataFrame,
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
- val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+ val previousExecutionIds = spark.listener.executionIdToData.keySet
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
df.collect()
}
sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
- val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+ val jobs = spark.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= expectedNumOfJobs)
if (jobs.size == expectedNumOfJobs) {
// If we can track all jobs, check the metric values
- val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+ val metricValues = spark.listener.getExecutionMetrics(executionId)
val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan(
df.queryExecution.executedPlan)).allNodes.filter { node =>
expectedMetrics.contains(node.id)
@@ -128,7 +128,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
// Assume the execution plan is
// WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1))
// TODO: update metrics in generated operators
- val ds = sqlContext.range(10).filter('id < 5)
+ val ds = spark.range(10).filter('id < 5)
testSparkPlanMetrics(ds.toDF(), 1, Map.empty)
}
@@ -157,7 +157,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("Sort metrics") {
// Assume the execution plan is
// WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
- val ds = sqlContext.range(10).sort('id)
+ val ds = spark.range(10).sort('id)
testSparkPlanMetrics(ds.toDF(), 2, Map.empty)
}
@@ -169,7 +169,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = sqlContext.sql(
+ val df = spark.sql(
"SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a")
testSparkPlanMetrics(df, 1, Map(
0L -> ("SortMergeJoin", Map(
@@ -187,7 +187,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = sqlContext.sql(
+ val df = spark.sql(
"SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a")
testSparkPlanMetrics(df, 1, Map(
0L -> ("SortMergeJoin", Map(
@@ -195,7 +195,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
"number of output rows" -> 8L)))
)
- val df2 = sqlContext.sql(
+ val df2 = spark.sql(
"SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a")
testSparkPlanMetrics(df2, 1, Map(
0L -> ("SortMergeJoin", Map(
@@ -241,7 +241,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = sqlContext.sql(
+ val df = spark.sql(
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
testSparkPlanMetrics(df, 3, Map(
@@ -269,7 +269,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = sqlContext.sql(
+ val df = spark.sql(
"SELECT * FROM testData2 JOIN testDataForJoin")
testSparkPlanMetrics(df, 1, Map(
0L -> ("CartesianProduct", Map(
@@ -280,19 +280,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("save metrics") {
withTempPath { file =>
- val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+ val previousExecutionIds = spark.listener.executionIdToData.keySet
// Assume the execution plan is
// PhysicalRDD(nodeId = 0)
person.select('name).write.format("json").save(file.getAbsolutePath)
sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
- val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+ val jobs = spark.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= 1)
- val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+ val metricValues = spark.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we still can check the value.
assert(metricValues.values.toSeq.exists(_ === "2"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
index 7b413dda1e..a7b2cfe7d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
@@ -217,7 +217,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3",
SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") {
withFileStreamSinkLog { sinkLog =>
- val fs = sinkLog.metadataPath.getFileSystem(sqlContext.sessionState.newHadoopConf())
+ val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf())
def listBatchFiles(): Set[String] = {
fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName =>
@@ -263,7 +263,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = {
withTempDir { file =>
- val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath)
+ val sinkLog = new FileStreamSinkLog(spark, file.getCanonicalPath)
f(sinkLog)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
index 5f92c5bb9b..ef2b479a56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
@@ -59,63 +59,63 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
test("HDFSMetadataLog: basic") {
withTempDir { temp =>
val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir
- val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath)
+ val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath)
assert(metadataLog.add(0, "batch0"))
assert(metadataLog.getLatest() === Some(0 -> "batch0"))
assert(metadataLog.get(0) === Some("batch0"))
assert(metadataLog.getLatest() === Some(0 -> "batch0"))
- assert(metadataLog.get(None, 0) === Array(0 -> "batch0"))
+ assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0"))
assert(metadataLog.add(1, "batch1"))
assert(metadataLog.get(0) === Some("batch0"))
assert(metadataLog.get(1) === Some("batch1"))
assert(metadataLog.getLatest() === Some(1 -> "batch1"))
- assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1"))
+ assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1"))
// Adding the same batch does nothing
metadataLog.add(1, "batch1-duplicated")
assert(metadataLog.get(0) === Some("batch0"))
assert(metadataLog.get(1) === Some("batch1"))
assert(metadataLog.getLatest() === Some(1 -> "batch1"))
- assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1"))
+ assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1"))
}
}
testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") {
- sqlContext.conf.setConfString(
+ spark.conf.set(
s"fs.$scheme.impl",
classOf[FakeFileSystem].getName)
withTempDir { temp =>
- val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp")
+ val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://$temp")
assert(metadataLog.add(0, "batch0"))
assert(metadataLog.getLatest() === Some(0 -> "batch0"))
assert(metadataLog.get(0) === Some("batch0"))
- assert(metadataLog.get(None, 0) === Array(0 -> "batch0"))
+ assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0"))
- val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp")
+ val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://$temp")
assert(metadataLog2.get(0) === Some("batch0"))
assert(metadataLog2.getLatest() === Some(0 -> "batch0"))
- assert(metadataLog2.get(None, 0) === Array(0 -> "batch0"))
+ assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0"))
}
}
test("HDFSMetadataLog: restart") {
withTempDir { temp =>
- val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
+ val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
assert(metadataLog.add(0, "batch0"))
assert(metadataLog.add(1, "batch1"))
assert(metadataLog.get(0) === Some("batch0"))
assert(metadataLog.get(1) === Some("batch1"))
assert(metadataLog.getLatest() === Some(1 -> "batch1"))
- assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1"))
+ assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1"))
- val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
+ val metadataLog2 = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
assert(metadataLog2.get(0) === Some("batch0"))
assert(metadataLog2.get(1) === Some("batch1"))
assert(metadataLog2.getLatest() === Some(1 -> "batch1"))
- assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1"))
+ assert(metadataLog2.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1"))
}
}
@@ -127,7 +127,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
new Thread() {
override def run(): Unit = waiter {
val metadataLog =
- new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
+ new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
try {
var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)
nextBatchId += 1
@@ -146,9 +146,10 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
}
waiter.await(timeout(10.seconds), dismissals(10))
- val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
+ val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString))
- assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString)))
+ assert(
+ metadataLog.get(None, Some(maxBatchId)) === (0 to maxBatchId).map(i => (i, i.toString)))
}
}
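
Besides the session rename, these hunks also track an internal signature change: `HDFSMetadataLog.get` now takes both bounds as `Option[Long]`, so a capped range reads `get(None, Some(batchId))`. An illustrative fragment only, since `HDFSMetadataLog` is an internal class in `org.apache.spark.sql.execution.streaming` and the surrounding `spark` session is assumed to exist:

```scala
import org.apache.spark.sql.execution.streaming.HDFSMetadataLog

// Illustrative only: mirrors the assertions above rather than public API usage.
val log = new HDFSMetadataLog[String](spark, "/tmp/metadata-dir")  // hypothetical path
log.add(0, "batch0")
log.add(1, "batch1")
// None = no lower bound; Some(1) = inclusive upper bound at batch 1.
val batches = log.get(None, Some(1))
assert(batches.toSeq == Seq(0L -> "batch0", 1L -> "batch1"))
```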
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 6be94eb24f..4fa1754253 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -27,10 +27,11 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.LocalSparkSession._
import org.apache.spark.LocalSparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{CompletionIterator, Utils}
@@ -54,19 +55,18 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
test("versioning and immutability") {
- withSpark(new SparkContext(sparkConf)) { sc =>
- val sqlContext = new SQLContext(sc)
+ withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
val rdd1 =
- makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
- sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+ makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(
increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
- sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+ val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
+ spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@@ -79,30 +79,30 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
def makeStoreRDD(
- sc: SparkContext,
+ spark: SparkSession,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
- implicit val sqlContext = new SQLContext(sc)
- makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+ implicit val sqlContext = spark.wrapped
+ makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
}
// Generate RDDs and state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
+ withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
for (i <- 1 to 20) {
- require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+ require(makeStoreRDD(spark, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
}
}
// With a new context, try using the earlier state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+ withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+ assert(makeStoreRDD(spark, Seq("a"), 20).collect().toSet === Set("a" -> 21))
}
}
test("usage with iterators - only gets and only puts") {
- withSpark(new SparkContext(sparkConf)) { sc =>
- implicit val sqlContext = new SQLContext(sc)
+ withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+ implicit val sqlContext = spark.wrapped
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
@@ -130,15 +130,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
}
- val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
- sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+ val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
- val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
- val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
}
@@ -149,8 +149,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val opId = 0
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- withSpark(new SparkContext(sparkConf)) { sc =>
- implicit val sqlContext = new SQLContext(sc)
+ withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+ implicit val sqlContext = spark.wrapped
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
@@ -159,7 +159,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
- val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
require(rdd.partitions.length === 2)
@@ -178,16 +178,20 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("distributed test") {
quietly {
- withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
- implicit val sqlContext = new SQLContext(sc)
+
+ withSparkSession(
+ SparkSession.builder
+ .config(sparkConf.setMaster("local-cluster[2, 1, 1024]"))
+ .getOrCreate()) { spark =>
+ implicit val sqlContext = spark.wrapped
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
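
The `withSparkSession { ... }` blocks above come from the new `LocalSparkSession` test helper added by this patch, used here as a loan pattern around a session built from `sparkConf`. A rough sketch of what such a helper looks like, not the exact file contents:

```scala
import org.apache.spark.sql.SparkSession

object LocalSparkSessionSketch {
  // Run the body against the given session, then always stop it so the next
  // test can build a fresh one from its own SparkConf.
  def withSparkSession[T](spark: SparkSession)(body: SparkSession => T): T = {
    try body(spark) finally spark.stop()
  }
}

// Usage, mirroring the tests above:
// withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => ... }
```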
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 67e44849ca..9eff42ab2d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -98,7 +98,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
}
- val listener = new SQLListener(sqlContext.sparkContext.conf)
+ val listener = new SQLListener(spark.sparkContext.conf)
val executionId = 0
val df = createTestDataFrame
val accumulatorIds =
@@ -239,7 +239,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
- val listener = new SQLListener(sqlContext.sparkContext.conf)
+ val listener = new SQLListener(spark.sparkContext.conf)
val executionId = 0
val df = createTestDataFrame
listener.onOtherEvent(SparkListenerSQLExecutionStart(
@@ -269,7 +269,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
- val listener = new SQLListener(sqlContext.sparkContext.conf)
+ val listener = new SQLListener(spark.sparkContext.conf)
val executionId = 0
val df = createTestDataFrame
listener.onOtherEvent(SparkListenerSQLExecutionStart(
@@ -310,7 +310,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("onExecutionEnd happens before onJobEnd(JobFailed)") {
- val listener = new SQLListener(sqlContext.sparkContext.conf)
+ val listener = new SQLListener(spark.sparkContext.conf)
val executionId = 0
val df = createTestDataFrame
listener.onOtherEvent(SparkListenerSQLExecutionStart(
@@ -340,16 +340,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("SPARK-11126: no memory leak when running non SQL jobs") {
- val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size
- sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ())
- sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000)
+ val previousStageNumber = spark.listener.stageIdToStageMetrics.size
+ spark.sparkContext.parallelize(1 to 10).foreach(i => ())
+ spark.sparkContext.listenerBus.waitUntilEmpty(10000)
// listener should ignore the non SQL stage
- assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber)
+ assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber)
- sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ())
- sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000)
+ spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ())
+ spark.sparkContext.listenerBus.waitUntilEmpty(10000)
// listener should save the SQL stage
- assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1)
+ assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber + 1)
}
test("SPARK-13055: history listener only tracks SQL metrics") {
@@ -401,8 +401,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
val sc = new SparkContext(conf)
try {
SQLContext.clearSqlListener()
- val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._
+ val spark = new SQLContext(sc)
+ import spark.implicits._
// Run 100 successful executions and 100 failed executions.
// Each execution only has one job and one stage.
for (i <- 0 until 100) {
@@ -418,12 +418,12 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
}
}
sc.listenerBus.waitUntilEmpty(10000)
- assert(sqlContext.listener.getCompletedExecutions.size <= 50)
- assert(sqlContext.listener.getFailedExecutions.size <= 50)
+ assert(spark.listener.getCompletedExecutions.size <= 50)
+ assert(spark.listener.getFailedExecutions.size <= 50)
// 50 for successful executions and 50 for failed executions
- assert(sqlContext.listener.executionIdToData.size <= 100)
- assert(sqlContext.listener.jobIdToExecutionId.size <= 100)
- assert(sqlContext.listener.stageIdToStageMetrics.size <= 100)
+ assert(spark.listener.executionIdToData.size <= 100)
+ assert(spark.listener.jobIdToExecutionId.size <= 100)
+ assert(spark.listener.stageIdToStageMetrics.size <= 100)
} finally {
sc.stop()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 73c2076a30..56f848b9a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -37,13 +37,12 @@ class CatalogSuite
with BeforeAndAfterEach
with SharedSQLContext {
- private def sparkSession: SparkSession = sqlContext.sparkSession
- private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog
+ private def sessionCatalog: SessionCatalog = spark.sessionState.catalog
private val utils = new CatalogTestUtils {
override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat"
override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat"
- override def newEmptyCatalog(): ExternalCatalog = sparkSession.sharedState.externalCatalog
+ override def newEmptyCatalog(): ExternalCatalog = spark.sharedState.externalCatalog
}
private def createDatabase(name: String): Unit = {
@@ -87,8 +86,8 @@ class CatalogSuite
private def testListColumns(tableName: String, dbName: Option[String]): Unit = {
val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, dbName))
val columns = dbName
- .map { db => sparkSession.catalog.listColumns(db, tableName) }
- .getOrElse { sparkSession.catalog.listColumns(tableName) }
+ .map { db => spark.catalog.listColumns(db, tableName) }
+ .getOrElse { spark.catalog.listColumns(tableName) }
assume(tableMetadata.schema.nonEmpty, "bad test")
assume(tableMetadata.partitionColumnNames.nonEmpty, "bad test")
assume(tableMetadata.bucketColumnNames.nonEmpty, "bad test")
@@ -108,85 +107,85 @@ class CatalogSuite
}
test("current database") {
- assert(sparkSession.catalog.currentDatabase == "default")
+ assert(spark.catalog.currentDatabase == "default")
assert(sessionCatalog.getCurrentDatabase == "default")
createDatabase("my_db")
- sparkSession.catalog.setCurrentDatabase("my_db")
- assert(sparkSession.catalog.currentDatabase == "my_db")
+ spark.catalog.setCurrentDatabase("my_db")
+ assert(spark.catalog.currentDatabase == "my_db")
assert(sessionCatalog.getCurrentDatabase == "my_db")
val e = intercept[AnalysisException] {
- sparkSession.catalog.setCurrentDatabase("unknown_db")
+ spark.catalog.setCurrentDatabase("unknown_db")
}
assert(e.getMessage.contains("unknown_db"))
}
test("list databases") {
- assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == Set("default"))
+ assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default"))
createDatabase("my_db1")
createDatabase("my_db2")
- assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet ==
+ assert(spark.catalog.listDatabases().collect().map(_.name).toSet ==
Set("default", "my_db1", "my_db2"))
dropDatabase("my_db1")
- assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet ==
+ assert(spark.catalog.listDatabases().collect().map(_.name).toSet ==
Set("default", "my_db2"))
}
test("list tables") {
- assert(sparkSession.catalog.listTables().collect().isEmpty)
+ assert(spark.catalog.listTables().collect().isEmpty)
createTable("my_table1")
createTable("my_table2")
createTempTable("my_temp_table")
- assert(sparkSession.catalog.listTables().collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables().collect().map(_.name).toSet ==
Set("my_table1", "my_table2", "my_temp_table"))
dropTable("my_table1")
- assert(sparkSession.catalog.listTables().collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables().collect().map(_.name).toSet ==
Set("my_table2", "my_temp_table"))
dropTable("my_temp_table")
- assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == Set("my_table2"))
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2"))
}
test("list tables with database") {
- assert(sparkSession.catalog.listTables("default").collect().isEmpty)
+ assert(spark.catalog.listTables("default").collect().isEmpty)
createDatabase("my_db1")
createDatabase("my_db2")
createTable("my_table1", Some("my_db1"))
createTable("my_table2", Some("my_db2"))
createTempTable("my_temp_table")
- assert(sparkSession.catalog.listTables("default").collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables("default").collect().map(_.name).toSet ==
Set("my_temp_table"))
- assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet ==
Set("my_table1", "my_temp_table"))
- assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet ==
Set("my_table2", "my_temp_table"))
dropTable("my_table1", Some("my_db1"))
- assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet ==
Set("my_temp_table"))
- assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet ==
Set("my_table2", "my_temp_table"))
dropTable("my_temp_table")
- assert(sparkSession.catalog.listTables("default").collect().map(_.name).isEmpty)
- assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).isEmpty)
- assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet ==
+ assert(spark.catalog.listTables("default").collect().map(_.name).isEmpty)
+ assert(spark.catalog.listTables("my_db1").collect().map(_.name).isEmpty)
+ assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet ==
Set("my_table2"))
val e = intercept[AnalysisException] {
- sparkSession.catalog.listTables("unknown_db")
+ spark.catalog.listTables("unknown_db")
}
assert(e.getMessage.contains("unknown_db"))
}
test("list functions") {
assert(Set("+", "current_database", "window").subsetOf(
- sparkSession.catalog.listFunctions().collect().map(_.name).toSet))
+ spark.catalog.listFunctions().collect().map(_.name).toSet))
createFunction("my_func1")
createFunction("my_func2")
createTempFunction("my_temp_func")
- val funcNames1 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet
+ val funcNames1 = spark.catalog.listFunctions().collect().map(_.name).toSet
assert(funcNames1.contains("my_func1"))
assert(funcNames1.contains("my_func2"))
assert(funcNames1.contains("my_temp_func"))
dropFunction("my_func1")
dropTempFunction("my_temp_func")
- val funcNames2 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet
+ val funcNames2 = spark.catalog.listFunctions().collect().map(_.name).toSet
assert(!funcNames2.contains("my_func1"))
assert(funcNames2.contains("my_func2"))
assert(!funcNames2.contains("my_temp_func"))
@@ -194,14 +193,14 @@ class CatalogSuite
test("list functions with database") {
assert(Set("+", "current_database", "window").subsetOf(
- sparkSession.catalog.listFunctions("default").collect().map(_.name).toSet))
+ spark.catalog.listFunctions("default").collect().map(_.name).toSet))
createDatabase("my_db1")
createDatabase("my_db2")
createFunction("my_func1", Some("my_db1"))
createFunction("my_func2", Some("my_db2"))
createTempFunction("my_temp_func")
- val funcNames1 = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet
- val funcNames2 = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet
+ val funcNames1 = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet
+ val funcNames2 = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet
assert(funcNames1.contains("my_func1"))
assert(!funcNames1.contains("my_func2"))
assert(funcNames1.contains("my_temp_func"))
@@ -210,14 +209,14 @@ class CatalogSuite
assert(funcNames2.contains("my_temp_func"))
dropFunction("my_func1", Some("my_db1"))
dropTempFunction("my_temp_func")
- val funcNames1b = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet
- val funcNames2b = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet
+ val funcNames1b = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet
+ val funcNames2b = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet
assert(!funcNames1b.contains("my_func1"))
assert(!funcNames1b.contains("my_temp_func"))
assert(funcNames2b.contains("my_func2"))
assert(!funcNames2b.contains("my_temp_func"))
val e = intercept[AnalysisException] {
- sparkSession.catalog.listFunctions("unknown_db")
+ spark.catalog.listFunctions("unknown_db")
}
assert(e.getMessage.contains("unknown_db"))
}
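
These assertions go through the public `spark.catalog` API that ships with SparkSession, so no test-only helpers are needed to list databases, tables, and functions. A small standalone sketch of the same calls:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("catalog-sketch").getOrCreate()
spark.range(5).createOrReplaceTempView("my_temp_table")

println(spark.catalog.currentDatabase)                               // "default"
println(spark.catalog.listDatabases().collect().map(_.name).toList)
println(spark.catalog.listTables().collect().map(_.name).toList)     // includes "my_temp_table"
println(spark.catalog.listFunctions().collect().map(_.name).take(5).toList)
spark.stop()
```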
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index b87f482941..7ead97bbf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -33,61 +33,61 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
test("programmatic ways of basic setting and getting") {
// Set a conf first.
- sqlContext.setConf(testKey, testVal)
+ spark.conf.set(testKey, testVal)
// Clear the conf.
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
// After clear, only overrideConfs used by unit test should be in the SQLConf.
- assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs)
+ assert(spark.conf.getAll === TestSQLContext.overrideConfs)
- sqlContext.setConf(testKey, testVal)
- assert(sqlContext.getConf(testKey) === testVal)
- assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
- assert(sqlContext.getAllConfs.contains(testKey))
+ spark.conf.set(testKey, testVal)
+ assert(spark.conf.get(testKey) === testVal)
+ assert(spark.conf.get(testKey, testVal + "_") === testVal)
+ assert(spark.conf.getAll.contains(testKey))
// Tests SQLConf as accessed from a SQLContext is mutable after
// the latter is initialized, unlike SparkConf inside a SparkContext.
- assert(sqlContext.getConf(testKey) === testVal)
- assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
- assert(sqlContext.getAllConfs.contains(testKey))
+ assert(spark.conf.get(testKey) === testVal)
+ assert(spark.conf.get(testKey, testVal + "_") === testVal)
+ assert(spark.conf.getAll.contains(testKey))
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
}
test("parse SQL set commands") {
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
sql(s"set $testKey=$testVal")
- assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
- assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
+ assert(spark.conf.get(testKey, testVal + "_") === testVal)
+ assert(spark.conf.get(testKey, testVal + "_") === testVal)
sql("set some.property=20")
- assert(sqlContext.getConf("some.property", "0") === "20")
+ assert(spark.conf.get("some.property", "0") === "20")
sql("set some.property = 40")
- assert(sqlContext.getConf("some.property", "0") === "40")
+ assert(spark.conf.get("some.property", "0") === "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
sql(s"set $key=$vs")
- assert(sqlContext.getConf(key, "0") === vs)
+ assert(spark.conf.get(key, "0") === vs)
sql(s"set $key=")
- assert(sqlContext.getConf(key, "0") === "")
+ assert(spark.conf.get(key, "0") === "")
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
}
test("deprecated property") {
- sqlContext.conf.clear()
- val original = sqlContext.conf.numShufflePartitions
+ spark.wrapped.conf.clear()
+ val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
try{
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
- assert(sqlContext.conf.numShufflePartitions === 10)
+ assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10)
} finally {
sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original")
}
}
test("invalid conf value") {
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
val e = intercept[IllegalArgumentException] {
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
@@ -95,35 +95,35 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") {
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
- assert(sqlContext.conf.targetPostShuffleInputSize === 100)
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
+ assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100)
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k")
- assert(sqlContext.conf.targetPostShuffleInputSize === 1024)
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k")
+ assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1024)
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M")
- assert(sqlContext.conf.targetPostShuffleInputSize === 1048576)
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M")
+ assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1048576)
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g")
- assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824)
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g")
+ assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1073741824)
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1")
- assert(sqlContext.conf.targetPostShuffleInputSize === -1)
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1")
+ assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === -1)
// Test overflow exception
intercept[IllegalArgumentException] {
// This value exceeds Long.MaxValue
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g")
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g")
}
intercept[IllegalArgumentException] {
// This value less than Long.MinValue
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
+ spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
}
- sqlContext.conf.clear()
+ spark.wrapped.conf.clear()
}
test("SparkSession can access configs set in SparkConf") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 47a1017caa..44d1b9ddda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -337,39 +337,39 @@ class JDBCSuite extends SparkFunSuite
}
test("Basic API") {
- assert(sqlContext.read.jdbc(
+ assert(spark.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}
test("Basic API with FetchSize") {
val properties = new Properties
properties.setProperty("fetchSize", "2")
- assert(sqlContext.read.jdbc(
+ assert(spark.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}
test("Partitioning via JDBCPartitioningInfo API") {
assert(
- sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
.collect().length === 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
- assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
+ assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
.collect().length === 3)
}
test("Partitioning on column that might have null values.") {
assert(
- sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties)
.collect().length === 4)
assert(
- sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties)
.collect().length === 4)
// partitioning on a nullable quoted column
assert(
- sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties)
.collect().length === 4)
}
@@ -429,9 +429,9 @@ class JDBCSuite extends SparkFunSuite
}
test("test DATE types") {
- val rows = sqlContext.read.jdbc(
+ val rows = spark.read.jdbc(
urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
@@ -439,8 +439,8 @@ class JDBCSuite extends SparkFunSuite
}
test("test DATE types in cache") {
- val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
+ spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().registerTempTable("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
@@ -448,7 +448,7 @@ class JDBCSuite extends SparkFunSuite
}
test("test types for null value") {
- val rows = sqlContext.read.jdbc(
+ val rows = spark.read.jdbc(
urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
@@ -495,7 +495,7 @@ class JDBCSuite extends SparkFunSuite
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
- val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
@@ -629,7 +629,7 @@ class JDBCSuite extends SparkFunSuite
// Regression test for bug SPARK-11788
val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");
val date = java.sql.Date.valueOf("1995-01-01")
- val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(0).getAs[java.sql.Timestamp](2)
@@ -639,7 +639,7 @@ class JDBCSuite extends SparkFunSuite
test("test credentials in the properties are not in plan output") {
val df = sql("SELECT * FROM parts")
val explain = ExplainCommand(df.queryExecution.logical, extended = true)
- sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
+ spark.executePlan(explain).executedPlan.executeCollect().foreach {
r => assert(!List("testPass", "testUser").exists(r.toString.contains))
}
// test the JdbcRelation toString output
@@ -649,9 +649,9 @@ class JDBCSuite extends SparkFunSuite
}
test("test credentials in the connection url are not in the plan output") {
- val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
val explain = ExplainCommand(df.queryExecution.logical, extended = true)
- sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
+ spark.executePlan(explain).executedPlan.executeCollect().foreach {
r => assert(!List("testPass", "testUser").exists(r.toString.contains))
}
}
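
The JDBC reads above are public API, just re-rooted on the session. A sketch of the three reader overloads being migrated; the URL, table names, and credentials are placeholders:

```scala
import java.util.Properties
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("jdbc-read-sketch").getOrCreate()
val url = "jdbc:h2:mem:testdb"            // placeholder connection URL
val props = new Properties()
props.setProperty("user", "sa")           // placeholder credentials
props.setProperty("fetchSize", "2")

// Whole-table read.
val people = spark.read.jdbc(url, "TEST.PEOPLE", props)
// Partitioned read: column, lower bound, upper bound, number of partitions.
val partitioned = spark.read.jdbc(url, "TEST.PEOPLE", "THEID", 0L, 4L, 3, props)
// Partitioned read driven by explicit WHERE clauses.
val byPredicates = spark.read.jdbc(url, "TEST.PEOPLE", Array("THEID < 2", "THEID >= 2"), props)
spark.stop()
```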
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee66931..48fa5f9822 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -88,50 +88,50 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
- assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+ assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
assert(
- 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+ 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}
test("CREATE with overwrite") {
- val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
- val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.DROPTEST", properties)
- assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties)
- assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
test("CREATE then INSERT to append") {
- val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
- val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url, "TEST.APPENDTEST", new Properties)
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties)
- assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
- assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
+ assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
+ assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
}
test("CREATE then INSERT to truncate") {
- val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
- val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
- assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
- assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
+ assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
test("Incompatible INSERT to append") {
- val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
- val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
intercept[org.apache.spark.SparkException] {
@@ -141,14 +141,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
test("INSERT to JDBC Datasource") {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
}
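
The write side mirrors the reader migration: `spark.createDataFrame` (or `toDF` on a local Seq) builds the input and `DataFrameWriter.jdbc` plus a SaveMode does the rest. A sketch with placeholder URL and table names:

```scala
import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().master("local[2]").appName("jdbc-write-sketch").getOrCreate()
import spark.implicits._

val url = "jdbc:h2:mem:testdb"                      // placeholder connection URL
val props = new Properties()
val df = Seq((1, "dave"), (2, "mary")).toDF("id", "name")

df.write.jdbc(url, "TEST.PEOPLE_COPY", props)                        // create table and insert
df.write.mode(SaveMode.Append).jdbc(url, "TEST.PEOPLE_COPY", props)  // append more rows
println(spark.read.jdbc(url, "TEST.PEOPLE_COPY", props).count())     // read it back
spark.stop()
```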
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 92061133cd..754aa32a30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -24,7 +24,7 @@ private[sql] abstract class DataSourceTest extends QueryTest {
// We want to test some edge cases.
protected lazy val caseInsensitiveContext: SQLContext = {
- val ctx = new SQLContext(sqlContext.sparkContext)
+ val ctx = new SQLContext(spark.sparkContext)
ctx.setConf(SQLConf.CASE_SENSITIVE, false)
ctx
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index a9b1970a7c..a2decadbe0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
val path = Utils.createTempDir()
path.delete()
- val df = sqlContext.range(100).select($"id", lit(1).as("data"))
+ val df = spark.range(100).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
- sqlContext.read.load(path.getCanonicalPath),
+ spark.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
@@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
val path = Utils.createTempDir()
path.delete()
- val base = sqlContext.range(100)
+ val base = spark.range(100)
val df = base.union(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
- sqlContext.read.load(path.getCanonicalPath),
+ spark.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
@@ -58,7 +58,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
withTempPath { f =>
val path = f.getAbsolutePath
Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
- assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
+ assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
}
}
}
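
The partitioned round trip above is also plain public API on the session. A condensed sketch, with a placeholder output path:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

val spark = SparkSession.builder().master("local[2]").appName("partitioned-write-sketch").getOrCreate()
import spark.implicits._

val out = "/tmp/partitioned-write-sketch"            // placeholder output path
val df = spark.range(100).select($"id", lit(1).as("data"))
df.write.partitionBy("id").parquet(out)              // one directory per distinct id

val back = spark.read.parquet(out)
// Partition columns are appended after the data columns, hence Seq("data", "id").
println(back.schema.map(_.name))
spark.stop()
```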
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
index 3d69c8a187..a743cdde40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -41,13 +41,13 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
override val streamingTimeout = 20.seconds
before {
- assert(sqlContext.streams.active.isEmpty)
- sqlContext.streams.resetTerminated()
+ assert(spark.streams.active.isEmpty)
+ spark.streams.resetTerminated()
}
after {
- assert(sqlContext.streams.active.isEmpty)
- sqlContext.streams.resetTerminated()
+ assert(spark.streams.active.isEmpty)
+ spark.streams.resetTerminated()
}
testQuietly("listing") {
@@ -57,26 +57,26 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
withQueriesOn(ds1, ds2, ds3) { queries =>
require(queries.size === 3)
- assert(sqlContext.streams.active.toSet === queries.toSet)
+ assert(spark.streams.active.toSet === queries.toSet)
val (q1, q2, q3) = (queries(0), queries(1), queries(2))
- assert(sqlContext.streams.get(q1.name).eq(q1))
- assert(sqlContext.streams.get(q2.name).eq(q2))
- assert(sqlContext.streams.get(q3.name).eq(q3))
+ assert(spark.streams.get(q1.name).eq(q1))
+ assert(spark.streams.get(q2.name).eq(q2))
+ assert(spark.streams.get(q3.name).eq(q3))
intercept[IllegalArgumentException] {
- sqlContext.streams.get("non-existent-name")
+ spark.streams.get("non-existent-name")
}
q1.stop()
- assert(sqlContext.streams.active.toSet === Set(q2, q3))
+ assert(spark.streams.active.toSet === Set(q2, q3))
val ex1 = withClue("no error while getting non-active query") {
intercept[IllegalArgumentException] {
- sqlContext.streams.get(q1.name)
+ spark.streams.get(q1.name)
}
}
assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched")
- assert(sqlContext.streams.get(q2.name).eq(q2))
+ assert(spark.streams.get(q2.name).eq(q2))
m2.addData(0) // q2 should terminate with error
@@ -86,11 +86,11 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
}
withClue("no error while getting non-active query") {
intercept[IllegalArgumentException] {
- sqlContext.streams.get(q2.name).eq(q2)
+ spark.streams.get(q2.name).eq(q2)
}
}
- assert(sqlContext.streams.active.toSet === Set(q3))
+ assert(spark.streams.active.toSet === Set(q3))
}
}
@@ -98,7 +98,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
val datasets = Seq.fill(5)(makeDataset._2)
withQueriesOn(datasets: _*) { queries =>
require(queries.size === datasets.size)
- assert(sqlContext.streams.active.toSet === queries.toSet)
+ assert(spark.streams.active.toSet === queries.toSet)
// awaitAnyTermination should be blocking
testAwaitAnyTermination(ExpectBlocked)
@@ -112,7 +112,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
testAwaitAnyTermination(ExpectNotBlocked)
// Resetting termination should make awaitAnyTermination() blocking again
- sqlContext.streams.resetTerminated()
+ spark.streams.resetTerminated()
testAwaitAnyTermination(ExpectBlocked)
// Terminate a query asynchronously with exception and see awaitAnyTermination throws
@@ -125,7 +125,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
testAwaitAnyTermination(ExpectException[SparkException])
// Resetting termination should make awaitAnyTermination() blocking again
- sqlContext.streams.resetTerminated()
+ spark.streams.resetTerminated()
testAwaitAnyTermination(ExpectBlocked)
// Terminate multiple queries, one with failure and see whether awaitAnyTermination throws
@@ -144,7 +144,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
val datasets = Seq.fill(6)(makeDataset._2)
withQueriesOn(datasets: _*) { queries =>
require(queries.size === datasets.size)
- assert(sqlContext.streams.active.toSet === queries.toSet)
+ assert(spark.streams.active.toSet === queries.toSet)
// awaitAnyTermination should be blocking or non-blocking depending on timeout values
testAwaitAnyTermination(
@@ -173,7 +173,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
ExpectNotBlocked, awaitTimeout = 4 seconds, expectedReturnedValue = true)
// Resetting termination should make awaitAnyTermination() blocking again
- sqlContext.streams.resetTerminated()
+ spark.streams.resetTerminated()
testAwaitAnyTermination(
ExpectBlocked,
awaitTimeout = 4 seconds,
@@ -196,7 +196,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
testBehaviorFor = 4 seconds)
// Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked
- sqlContext.streams.resetTerminated()
+ spark.streams.resetTerminated()
val q3 = stopRandomQueryAsync(2 seconds, withError = true)
testAwaitAnyTermination(
ExpectNotBlocked,
@@ -214,7 +214,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
// Terminate multiple queries, one with failure and see whether awaitAnyTermination throws
// the exception
- sqlContext.streams.resetTerminated()
+ spark.streams.resetTerminated()
val q4 = stopRandomQueryAsync(10 milliseconds, withError = false)
testAwaitAnyTermination(
@@ -238,7 +238,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
val df = ds.toDF
val metadataRoot =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
- query = sqlContext
+ query = spark
.streams
.startQuery(
StreamExecution.nextName,
@@ -272,10 +272,10 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
def awaitTermFunc(): Unit = {
if (awaitTimeout != null && awaitTimeout.toMillis > 0) {
- val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis)
+ val returnedValue = spark.streams.awaitAnyTermination(awaitTimeout.toMillis)
assert(returnedValue === expectedReturnedValue, "Returned value does not match expected")
} else {
- sqlContext.streams.awaitAnyTermination()
+ spark.streams.awaitAnyTermination()
}
}
@@ -287,7 +287,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
import scala.concurrent.ExecutionContext.Implicits.global
- val activeQueries = sqlContext.streams.active
+ val activeQueries = spark.streams.active
val queryToStop = activeQueries(Random.nextInt(activeQueries.length))
Future {
Thread.sleep(stopAfter.toMillis)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index c7b2b99822..cb53b2b1aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -54,18 +54,18 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
override def sourceSchema(
- sqlContext: SQLContext,
+ spark: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
LastOptions.parameters = parameters
LastOptions.schema = schema
- LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters)
+ LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters)
("dummySource", fakeSchema)
}
override def createSource(
- sqlContext: SQLContext,
+ spark: SQLContext,
metadataPath: String,
schema: Option[StructType],
providerName: String,
@@ -73,14 +73,14 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
LastOptions.parameters = parameters
LastOptions.schema = schema
LastOptions.mockStreamSourceProvider.createSource(
- sqlContext, metadataPath, schema, providerName, parameters)
+ spark, metadataPath, schema, providerName, parameters)
new Source {
override def schema: StructType = fakeSchema
override def getOffset: Option[Offset] = Some(new LongOffset(0))
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
- import sqlContext.implicits._
+ import spark.implicits._
Seq[Int]().toDS().toDF()
}
@@ -88,12 +88,12 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
}
override def createSink(
- sqlContext: SQLContext,
+ spark: SQLContext,
parameters: Map[String, String],
partitionColumns: Seq[String]): Sink = {
LastOptions.parameters = parameters
LastOptions.partitionColumns = partitionColumns
- LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns)
+ LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns)
new Sink {
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
}
@@ -107,11 +107,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
after {
- sqlContext.streams.active.foreach(_.stop())
+ spark.streams.active.foreach(_.stop())
}
test("resolve default source") {
- sqlContext.read
+ spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
.write
@@ -122,7 +122,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("resolve full class") {
- sqlContext.read
+ spark.read
.format("org.apache.spark.sql.streaming.test.DefaultSource")
.stream()
.write
@@ -136,7 +136,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
val map = new java.util.HashMap[String, String]
map.put("opt3", "3")
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.option("opt1", "1")
.options(Map("opt2" -> "2"))
@@ -164,7 +164,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("partitioning") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
@@ -204,7 +204,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("stream paths") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.option("checkpointLocation", newMetadataDir)
.stream("/test")
@@ -223,7 +223,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("test different data types for options") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.option("intOpt", 56)
.option("boolOpt", false)
@@ -253,7 +253,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
/** Start a query with a specific name */
def startQueryWithName(name: String = ""): ContinuousQuery = {
- sqlContext.read
+ spark.read
.format("org.apache.spark.sql.streaming.test")
.stream("/test")
.write
@@ -265,7 +265,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
/** Start a query without specifying a name */
def startQueryWithoutName(): ContinuousQuery = {
- sqlContext.read
+ spark.read
.format("org.apache.spark.sql.streaming.test")
.stream("/test")
.write
@@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
/** Get the names of active streams */
def activeStreamNames: Set[String] = {
- val streams = sqlContext.streams.active
+ val streams = spark.streams.active
val names = streams.map(_.name).toSet
assert(streams.length === names.size, s"names of active queries are not unique: $names")
names
@@ -307,11 +307,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
q1.stop()
val q5 = startQueryWithName("name")
assert(activeStreamNames.contains("name"))
- sqlContext.streams.active.foreach(_.stop())
+ spark.streams.active.foreach(_.stop())
}
test("trigger") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream("/test")
@@ -339,11 +339,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
val checkpointLocation = newMetadataDir
- val df1 = sqlContext.read
+ val df1 = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
- val df2 = sqlContext.read
+ val df2 = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
@@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
q.stop()
verify(LastOptions.mockStreamSourceProvider).createSource(
- sqlContext,
+ spark.wrapped,
checkpointLocation + "/sources/0",
None,
"org.apache.spark.sql.streaming.test",
Map.empty)
verify(LastOptions.mockStreamSourceProvider).createSource(
- sqlContext,
+ spark.wrapped,
checkpointLocation + "/sources/1",
None,
"org.apache.spark.sql.streaming.test",
@@ -372,35 +372,35 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
test("check trigger() can only be called on continuous queries") {
- val df = sqlContext.read.text(newTextInput)
+ val df = spark.read.text(newTextInput)
val w = df.write.option("checkpointLocation", newMetadataDir)
val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds")))
assert(e.getMessage == "trigger() can only be called on continuous queries;")
}
test("check queryName() can only be called on continuous queries") {
- val df = sqlContext.read.text(newTextInput)
+ val df = spark.read.text(newTextInput)
val w = df.write.option("checkpointLocation", newMetadataDir)
val e = intercept[AnalysisException](w.queryName("queryName"))
assert(e.getMessage == "queryName() can only be called on continuous queries;")
}
test("check startStream() can only be called on continuous queries") {
- val df = sqlContext.read.text(newTextInput)
+ val df = spark.read.text(newTextInput)
val w = df.write.option("checkpointLocation", newMetadataDir)
val e = intercept[AnalysisException](w.startStream())
assert(e.getMessage == "startStream() can only be called on continuous queries;")
}
test("check startStream(path) can only be called on continuous queries") {
- val df = sqlContext.read.text(newTextInput)
+ val df = spark.read.text(newTextInput)
val w = df.write.option("checkpointLocation", newMetadataDir)
val e = intercept[AnalysisException](w.startStream("non_exist_path"))
assert(e.getMessage == "startStream() can only be called on continuous queries;")
}
test("check mode(SaveMode) can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -409,7 +409,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check mode(string) can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -418,7 +418,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check bucketBy() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -427,7 +427,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check sortBy() can only be called on non-continuous queries;") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -436,7 +436,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check save(path) can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -445,7 +445,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check save() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -454,7 +454,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check insertInto() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -463,7 +463,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check saveAsTable() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -472,7 +472,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check jdbc() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -481,7 +481,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check json() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -490,7 +490,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check parquet() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -499,7 +499,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check orc() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -508,7 +508,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check text() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
@@ -517,7 +517,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
}
test("check csv() can only be called on non-continuous queries") {
- val df = sqlContext.read
+ val df = spark.read
.format("org.apache.spark.sql.streaming.test")
.stream()
val w = df.write
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index e937fc3e87..6238b74ffa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -40,11 +40,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
val path = Utils.createTempDir()
path.delete()
- val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
+ val hadoopConf = spark.sparkContext.hadoopConfiguration
val fileFormat = new parquet.DefaultSource()
def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = {
- val df = sqlContext
+ val df = spark
.range(start, end, 1, numPartitions)
.select($"id", lit(100).as("data"))
val writer = new FileStreamSinkWriter(
@@ -56,7 +56,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
val files1 = writeRange(0, 10, 2)
assert(files1.size === 2, s"unexpected number of files: $files1")
checkFilesExist(path, files1, "file not written")
- checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100)))
+ checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100)))
// Append and check whether new files are written correctly and old files still exist
val files2 = writeRange(10, 20, 3)
@@ -64,7 +64,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
assert(files2.intersect(files1).isEmpty, "old files returned")
checkFilesExist(path, files2, s"New file not written")
checkFilesExist(path, files1, s"Old file not found")
- checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100)))
+ checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100)))
}
test("FileStreamSinkWriter - partitioned data") {
@@ -72,11 +72,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
val path = Utils.createTempDir()
path.delete()
- val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
+ val hadoopConf = spark.sparkContext.hadoopConfiguration
val fileFormat = new parquet.DefaultSource()
def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = {
- val df = sqlContext
+ val df = spark
.range(start, end, 1, numPartitions)
.flatMap(x => Iterator(x, x, x)).toDF("id")
.select($"id", lit(100).as("data1"), lit(1000).as("data2"))
@@ -103,7 +103,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
checkOneFileWrittenPerKey(0 until 10, files1)
val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _))
- checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1)
+ checkAnswer(spark.read.load(path.getCanonicalPath), answer1)
// Append and check whether new files are written correctly and old files still exist
val files2 = writeRange(0, 20, 3)
@@ -114,7 +114,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
checkOneFileWrittenPerKey(0 until 20, files2)
val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _))
- checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1 ++ answer2)
+ checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2)
}
test("FileStreamSink - unpartitioned writing and batch reading") {
@@ -139,7 +139,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
query.processAllAvailable()
}
- val outputDf = sqlContext.read.parquet(outputDir).as[Int]
+ val outputDf = spark.read.parquet(outputDir).as[Int]
checkDataset(outputDf, 1, 2, 3)
} finally {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index a62852b512..4b95d65627 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -103,9 +103,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
val reader =
if (schema.isDefined) {
- sqlContext.read.format(format).schema(schema.get)
+ spark.read.format(format).schema(schema.get)
} else {
- sqlContext.read.format(format)
+ spark.read.format(format)
}
reader.stream(path)
}
@@ -149,7 +149,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
format: Option[String],
path: Option[String],
schema: Option[StructType] = None): StructType = {
- val reader = sqlContext.read
+ val reader = spark.read
format.foreach(reader.format)
schema.foreach(reader.schema)
val df =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
index 50703e532f..4efb7cf52d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
@@ -100,7 +100,7 @@ class FileStressSuite extends StreamTest with SharedSQLContext {
}
writer.start()
- val input = sqlContext.read.format("text").stream(inputDir)
+ val input = spark.read.format("text").stream(inputDir)
def startStream(): ContinuousQuery = {
val output = input
@@ -150,6 +150,6 @@ class FileStressSuite extends StreamTest with SharedSQLContext {
streamThread.join()
logError(s"Stream restarted $failures times.")
- assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords)
+ assert(spark.read.parquet(outputDir).distinct().count() == numRecords)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
index 74ca3977d6..09c35bbf2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
@@ -44,13 +44,13 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext {
query.processAllAvailable()
checkDataset(
- sqlContext.table("memStream").as[Int],
+ spark.table("memStream").as[Int],
1, 2, 3)
input.addData(4, 5, 6)
query.processAllAvailable()
checkDataset(
- sqlContext.table("memStream").as[Int],
+ spark.table("memStream").as[Int],
1, 2, 3, 4, 5, 6)
query.stop()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index bcd3cba55a..6a8b280174 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -94,7 +94,7 @@ class StreamSuite extends StreamTest with SharedSQLContext {
.startStream(outputDir.getAbsolutePath)
try {
query.processAllAvailable()
- val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
+ val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long]
checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
} finally {
query.stop()
@@ -103,7 +103,7 @@ class StreamSuite extends StreamTest with SharedSQLContext {
}
}
- val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
+ val df = spark.read.format(classOf[FakeDefaultSource].getName).stream()
assertDF(df)
assertDF(df)
}
@@ -162,13 +162,13 @@ class FakeDefaultSource extends StreamSourceProvider {
private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
override def sourceSchema(
- sqlContext: SQLContext,
+ spark: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema)
override def createSource(
- sqlContext: SQLContext,
+ spark: SQLContext,
metadataPath: String,
schema: Option[StructType],
providerName: String,
@@ -190,7 +190,7 @@ class FakeDefaultSource extends StreamSourceProvider {
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
- sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+ spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 7fa6760b71..03369c5a48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -20,17 +20,17 @@ package org.apache.spark.sql.test
import java.nio.charset.StandardCharsets
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
/**
* A collection of sample data used in SQL tests.
*/
private[sql] trait SQLTestData { self =>
- protected def sqlContext: SQLContext
+ protected def spark: SparkSession
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self.sqlContext
+ protected override def _sqlContext: SQLContext = self.spark.wrapped
}
import internalImplicits._
@@ -39,21 +39,21 @@ private[sql] trait SQLTestData { self =>
// Note: all test data should be lazy because the SQLContext is not set up yet.
protected lazy val emptyTestData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
Seq.empty[Int].map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("emptyTestData")
df
}
protected lazy val testData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("testData")
df
}
protected lazy val testData2: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
@@ -65,7 +65,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val testData3: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
TestData3(1, None) ::
TestData3(2, Some(2)) :: Nil).toDF()
df.registerTempTable("testData3")
@@ -73,14 +73,14 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val negativeData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
df.registerTempTable("negativeData")
df
}
protected lazy val largeAndSmallInts: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
@@ -92,7 +92,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val decimalData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
@@ -104,7 +104,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val binaryData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) ::
BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) ::
@@ -115,7 +115,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val upperCaseData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
UpperCaseData(1, "A") ::
UpperCaseData(2, "B") ::
UpperCaseData(3, "C") ::
@@ -127,7 +127,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val lowerCaseData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
@@ -137,7 +137,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val arrayData: RDD[ArrayData] = {
- val rdd = sqlContext.sparkContext.parallelize(
+ val rdd = spark.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
rdd.toDF().registerTempTable("arrayData")
@@ -145,7 +145,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val mapData: RDD[MapData] = {
- val rdd = sqlContext.sparkContext.parallelize(
+ val rdd = spark.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
@@ -156,13 +156,13 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val repeatedData: RDD[StringData] = {
- val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
+ val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("repeatedData")
rdd
}
protected lazy val nullableRepeatedData: RDD[StringData] = {
- val rdd = sqlContext.sparkContext.parallelize(
+ val rdd = spark.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("nullableRepeatedData")
@@ -170,7 +170,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val nullInts: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
@@ -180,7 +180,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val allNulls: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
@@ -190,7 +190,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val nullStrings: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
NullStrings(3, null) :: Nil).toDF()
@@ -199,13 +199,13 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val tableName: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
+ val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF()
df.registerTempTable("tableName")
df
}
protected lazy val unparsedStrings: RDD[String] = {
- sqlContext.sparkContext.parallelize(
+ spark.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
@@ -214,13 +214,13 @@ private[sql] trait SQLTestData { self =>
// An RDD with 4 elements and 8 partitions
protected lazy val withEmptyParts: RDD[IntField] = {
- val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
+ val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8)
rdd.toDF().registerTempTable("withEmptyParts")
rdd
}
protected lazy val person: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
Person(0, "mike", 30) ::
Person(1, "jim", 20) :: Nil).toDF()
df.registerTempTable("person")
@@ -228,7 +228,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val salary: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
Salary(0, 2000.0) ::
Salary(1, 1000.0) :: Nil).toDF()
df.registerTempTable("salary")
@@ -236,7 +236,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val complexData: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
Nil).toDF()
@@ -245,7 +245,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val courseSales: DataFrame = {
- val df = sqlContext.sparkContext.parallelize(
+ val df = spark.sparkContext.parallelize(
CourseSales("dotNET", 2012, 10000) ::
CourseSales("Java", 2012, 20000) ::
CourseSales("dotNET", 2012, 5000) ::
@@ -259,7 +259,7 @@ private[sql] trait SQLTestData { self =>
* Initialize all test data such that all temp tables are properly registered.
*/
def loadTestData(): Unit = {
- assert(sqlContext != null, "attempted to initialize test data before SQLContext.")
+ assert(spark != null, "attempted to initialize test data before SparkSession.")
emptyTestData
testData
testData2
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 6d2b95e83a..a49a8c9f2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -50,23 +50,23 @@ private[sql] trait SQLTestUtils
with BeforeAndAfterAll
with SQLTestData { self =>
- protected def sparkContext = sqlContext.sparkContext
+ protected def sparkContext = spark.sparkContext
// Whether to materialize all test data before the first test is run
private var loadTestDataBeforeTests = false
// Shorthand for running a query using our SQLContext
- protected lazy val sql = sqlContext.sql _
+ protected lazy val sql = spark.sql _
/**
* A helper object for importing SQL implicits.
*
- * Note that the alternative of importing `sqlContext.implicits._` is not possible here.
+ * Note that the alternative of importing `spark.implicits._` is not possible here.
* This is because we create the [[SQLContext]] immediately before the first test is run,
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self.sqlContext
+ protected override def _sqlContext: SQLContext = self.spark.wrapped
}
/**
@@ -92,12 +92,12 @@ private[sql] trait SQLTestUtils
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
- (keys, values).zipped.foreach(sqlContext.conf.setConfString)
+ val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
+ (keys, values).zipped.foreach(spark.conf.set)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
- case (key, None) => sqlContext.conf.unsetConf(key)
+ case (key, Some(value)) => spark.conf.set(key, value)
+ case (key, None) => spark.conf.unset(key)
}
}
}
@@ -138,9 +138,9 @@ private[sql] trait SQLTestUtils
// temp tables that never got created.
functions.foreach { case (functionName, isTemporary) =>
val withTemporary = if (isTemporary) "TEMPORARY" else ""
- sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName")
+ spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName")
assert(
- !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
+ !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
s"Function $functionName should have been dropped. But, it still exists.")
}
}
@@ -153,7 +153,7 @@ private[sql] trait SQLTestUtils
try f finally {
// If the test failed part way, we don't want to mask the failure by failing to remove
// temp tables that never got created.
- try tableNames.foreach(sqlContext.dropTempTable) catch {
+ try tableNames.foreach(spark.catalog.dropTempTable) catch {
case _: NoSuchTableException =>
}
}
@@ -165,7 +165,7 @@ private[sql] trait SQLTestUtils
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
try f finally {
tableNames.foreach { name =>
- sqlContext.sql(s"DROP TABLE IF EXISTS $name")
+ spark.sql(s"DROP TABLE IF EXISTS $name")
}
}
}
@@ -176,7 +176,7 @@ private[sql] trait SQLTestUtils
protected def withView(viewNames: String*)(f: => Unit): Unit = {
try f finally {
viewNames.foreach { name =>
- sqlContext.sql(s"DROP VIEW IF EXISTS $name")
+ spark.sql(s"DROP VIEW IF EXISTS $name")
}
}
}
@@ -191,12 +191,12 @@ private[sql] trait SQLTestUtils
val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
try {
- sqlContext.sql(s"CREATE DATABASE $dbName")
+ spark.sql(s"CREATE DATABASE $dbName")
} catch { case cause: Throwable =>
fail("Failed to create temporary database", cause)
}
- try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
+ try f(dbName) finally spark.sql(s"DROP DATABASE $dbName CASCADE")
}
/**
@@ -204,8 +204,8 @@ private[sql] trait SQLTestUtils
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
- sqlContext.sessionState.catalog.setCurrentDatabase(db)
- try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default")
+ spark.sessionState.catalog.setCurrentDatabase(db)
+ try f finally spark.sessionState.catalog.setCurrentDatabase("default")
}
/**
@@ -221,7 +221,7 @@ private[sql] trait SQLTestUtils
.execute()
.map(row => Row.fromSeq(row.copy().toSeq(schema)))
- sqlContext.createDataFrame(childRDD, schema)
+ spark.createDataFrame(childRDD, schema)
}
/**
@@ -229,7 +229,7 @@ private[sql] trait SQLTestUtils
* way to construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
- Dataset.ofRows(sqlContext.sparkSession, plan)
+ Dataset.ofRows(spark, plan)
}
/**
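The conf save/restore idiom above, extracted as a standalone sketch against the public `spark.conf` API (illustrative; it takes the session explicitly rather than using the trait's member):

    import scala.util.Try
    import org.apache.spark.sql.SparkSession

    def withSQLConf(spark: SparkSession)(pairs: (String, String)*)(f: => Unit): Unit = {
      val (keys, values) = pairs.unzip
      // Remember the previous values so the overrides can be undone afterwards.
      val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
      keys.zip(values).foreach { case (k, v) => spark.conf.set(k, v) }
      try f finally {
        keys.zip(currentValues).foreach {
          case (key, Some(value)) => spark.conf.set(key, value)   // restore the old value
          case (key, None)        => spark.conf.unset(key)        // or clear it entirely
        }
      }
    }

A caller wraps a test body, e.g. withSQLConf(spark)("spark.sql.shuffle.partitions" -> "5") { ... }.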
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index 914c6a5509..620bfa995a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -17,37 +17,42 @@
package org.apache.spark.sql.test
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{SparkSession, SQLContext}
/**
- * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]].
+ * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
*/
trait SharedSQLContext extends SQLTestUtils {
protected val sparkConf = new SparkConf()
/**
- * The [[TestSQLContext]] to use for all tests in this suite.
+ * The [[TestSparkSession]] to use for all tests in this suite.
*
* By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
* mode with the default test configurations.
*/
- private var _ctx: TestSQLContext = null
+ private var _spark: TestSparkSession = null
+
+ /**
+ * The [[TestSparkSession]] to use for all tests in this suite.
+ */
+ protected implicit def spark: SparkSession = _spark
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
- protected implicit def sqlContext: SQLContext = _ctx
+ protected implicit def sqlContext: SQLContext = _spark.wrapped
/**
- * Initialize the [[TestSQLContext]].
+ * Initialize the [[TestSparkSession]].
*/
protected override def beforeAll(): Unit = {
SQLContext.clearSqlListener()
- if (_ctx == null) {
- _ctx = new TestSQLContext(sparkConf)
+ if (_spark == null) {
+ _spark = new TestSparkSession(sparkConf)
}
// Ensure we have initialized the context before calling parent code
super.beforeAll()
@@ -58,9 +63,9 @@ trait SharedSQLContext extends SQLTestUtils {
*/
protected override def afterAll(): Unit = {
try {
- if (_ctx != null) {
- _ctx.sparkContext.stop()
- _ctx = null
+ if (_spark != null) {
+ _spark.stop()
+ _spark = null
}
} finally {
super.afterAll()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 5ef80b9aa3..785e3452a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -18,44 +18,32 @@
package org.apache.spark.sql.test
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{SessionState, SQLConf}
/**
- * A special [[SQLContext]] prepared for testing.
+ * A special [[SparkSession]] prepared for testing.
*/
-private[sql] class TestSQLContext(
- @transient override val sparkSession: SparkSession,
- isRootContext: Boolean)
- extends SQLContext(sparkSession, isRootContext) { self =>
-
- def this(sc: SparkContext) {
- this(new TestSparkSession(sc), true)
- }
-
+private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
def this(sparkConf: SparkConf) {
this(new SparkContext("local[2]", "test-sql-context",
sparkConf.set("spark.sql.testkey", "true")))
}
def this() {
- this(new SparkConf)
- }
-
- // Needed for Java tests
- def loadTestData(): Unit = {
- testData.loadTestData()
- }
-
- private object testData extends SQLTestData {
- protected override def sqlContext: SQLContext = self
+ this {
+ val conf = new SparkConf()
+ conf.set("spark.sql.testkey", "true")
+
+ val spark = SparkSession.builder
+ .master("local[2]")
+ .appName("test-sql-context")
+ .config(conf)
+ .getOrCreate()
+ spark.sparkContext
+ }
}
-}
-
-
-private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
-
@transient
protected[sql] override lazy val sessionState: SessionState = new SessionState(self) {
override lazy val conf: SQLConf = {
@@ -70,6 +58,14 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
}
+ // Needed for Java tests
+ def loadTestData(): Unit = {
+ testData.loadTestData()
+ }
+
+ private object testData extends SQLTestData {
+ protected override def spark: SparkSession = self
+ }
}
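The new auxiliary constructor above boils down to the builder pattern; as a standalone sketch (illustrative, not part of the patch), a throwaway test session can be created and torn down like this:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    val conf = new SparkConf().set("spark.sql.testkey", "true")   // marker conf the test harness checks
    val spark = SparkSession.builder
      .master("local[2]")
      .appName("test-sql-context")
      .config(conf)
      .getOrCreate()
    try {
      assert(spark.sparkContext.getConf.get("spark.sql.testkey") == "true")
    } finally {
      spark.stop()                                                // also stops the underlying SparkContext
    }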
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
index 54acd4db3c..8788898fc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
@@ -36,11 +36,11 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with
import testImplicits._
after {
- sqlContext.streams.active.foreach(_.stop())
- assert(sqlContext.streams.active.isEmpty)
+ spark.streams.active.foreach(_.stop())
+ assert(spark.streams.active.isEmpty)
assert(addedListeners.isEmpty)
// Make sure we don't leak any events to the next test
- sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000)
+ spark.sparkContext.listenerBus.waitUntilEmpty(10000)
}
test("single listener") {
@@ -112,17 +112,17 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with
val listener1 = new QueryStatusCollector
val listener2 = new QueryStatusCollector
- sqlContext.streams.addListener(listener1)
+ spark.streams.addListener(listener1)
assert(isListenerActive(listener1) === true)
assert(isListenerActive(listener2) === false)
- sqlContext.streams.addListener(listener2)
+ spark.streams.addListener(listener2)
assert(isListenerActive(listener1) === true)
assert(isListenerActive(listener2) === true)
- sqlContext.streams.removeListener(listener1)
+ spark.streams.removeListener(listener1)
assert(isListenerActive(listener1) === false)
assert(isListenerActive(listener2) === true)
} finally {
- addedListeners.foreach(sqlContext.streams.removeListener)
+ addedListeners.foreach(spark.streams.removeListener)
}
}
@@ -146,18 +146,18 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with
private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = {
try {
failAfter(1 minute) {
- sqlContext.streams.addListener(listener)
+ spark.streams.addListener(listener)
body
}
} finally {
- sqlContext.streams.removeListener(listener)
+ spark.streams.removeListener(listener)
}
}
private def addedListeners(): Array[ContinuousQueryListener] = {
val listenerBusMethod =
PrivateMethod[ContinuousQueryListenerBus]('listenerBus)
- val listenerBus = sqlContext.streams invokePrivate listenerBusMethod()
+ val listenerBus = spark.streams invokePrivate listenerBusMethod()
listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener])
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 8a0578c1ff..3ae5ce610d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -39,7 +39,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
metrics += ((funcName, qe, duration))
}
}
- sqlContext.listenerManager.register(listener)
+ spark.listenerManager.register(listener)
val df = Seq(1 -> "a").toDF("i", "j")
df.select("i").collect()
@@ -55,7 +55,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
assert(metrics(1)._3 > 0)
- sqlContext.listenerManager.unregister(listener)
+ spark.listenerManager.unregister(listener)
}
test("execute callback functions when a DataFrame action failed") {
@@ -68,7 +68,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
// Only test failed case here, so no need to implement `onSuccess`
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
}
- sqlContext.listenerManager.register(listener)
+ spark.listenerManager.register(listener)
val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") }
val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j")
@@ -82,7 +82,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
assert(metrics(0)._2.analyzed.isInstanceOf[Project])
assert(metrics(0)._3.getMessage == e.getMessage)
- sqlContext.listenerManager.unregister(listener)
+ spark.listenerManager.unregister(listener)
}
test("get numRows metrics by callback") {
@@ -99,7 +99,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
metrics += metric.value
}
}
- sqlContext.listenerManager.register(listener)
+ spark.listenerManager.register(listener)
val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
df.collect()
@@ -111,7 +111,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
assert(metrics(1) === 1)
assert(metrics(2) === 2)
- sqlContext.listenerManager.unregister(listener)
+ spark.listenerManager.unregister(listener)
}
// TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never
@@ -131,10 +131,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
metrics += bottomAgg.longMetric("dataSize").value
}
}
- sqlContext.listenerManager.register(listener)
+ spark.listenerManager.register(listener)
val sparkListener = new SaveInfoListener
- sqlContext.sparkContext.addSparkListener(sparkListener)
+ spark.sparkContext.addSparkListener(sparkListener)
val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j")
df.groupBy("i").count().collect()
@@ -157,6 +157,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
assert(metrics(0) == topAggDataSize)
assert(metrics(1) == bottomAggDataSize)
- sqlContext.listenerManager.unregister(listener)
+ spark.listenerManager.unregister(listener)
}
}
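For reference, the listener-manager API these tests drive through `spark.listenerManager` can be used on its own roughly as below (a sketch, assuming an existing session `spark`; the listener class and the action used are illustrative):

    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    class CountingListener extends QueryExecutionListener {
      @volatile var successes = 0
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        successes += 1                                    // one call per successful DataFrame action
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
    }

    val listener = new CountingListener
    spark.listenerManager.register(listener)
    try {
      spark.range(10).count()                             // a successful action; fires onSuccess
    } finally {
      spark.listenerManager.unregister(listener)          // always detach, as the suite does
    }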
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
index 154ada3daa..9bf84ab1fb 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.hive.test
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SQLContext
trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
- protected val sqlContext: SQLContext = TestHive
+ protected val spark: SparkSession = TestHive.sparkSession
protected val hiveContext: TestHiveContext = TestHive
protected override def afterAll(): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index a7782abc39..72736ee55b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -34,13 +34,13 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
val bytes = Array[Byte](1, 2, 3, 4)
Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0")
- sqlContext
+ spark
.range(10)
.select('id as 'key, concat(lit("val_"), 'id) as 'value)
.write
.saveAsTable("t1")
- sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
+ spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
}
override protected def afterAll(): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
index 34c2773581..9abefa5f28 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
@@ -33,16 +33,16 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
sql("DROP TABLE IF EXISTS parquet_t2")
sql("DROP TABLE IF EXISTS t0")
- sqlContext.range(10).write.saveAsTable("parquet_t0")
+ spark.range(10).write.saveAsTable("parquet_t0")
sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0")
- sqlContext
+ spark
.range(10)
.select('id as 'key, concat(lit("val_"), 'id) as 'value)
.write
.saveAsTable("parquet_t1")
- sqlContext
+ spark
.range(10)
.select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd)
.write
@@ -52,7 +52,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1))
}
- sqlContext
+ spark
.range(10)
.select(
createArray('id).as("arr"),
@@ -394,7 +394,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
Seq("orc", "json", "parquet").foreach { format =>
val tableName = s"${format}_parquet_t0"
withTable(tableName) {
- sqlContext.range(10).write.format(format).saveAsTable(tableName)
+ spark.range(10).write.format(format).saveAsTable(tableName)
checkHiveQl(s"SELECT id FROM $tableName")
}
}
@@ -458,7 +458,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
test("plans with non-SQL expressions") {
- sqlContext.udf.register("foo", (_: Int) * 2)
+ spark.udf.register("foo", (_: Int) * 2)
intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
index 27c9e992de..31755f56ec 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
@@ -64,7 +64,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
""".stripMargin)
}
- checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext.sparkSession, plan))
+ checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan))
}
protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index 61910b8e6b..093cd3a96c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -30,8 +30,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
override protected def beforeEach(): Unit = {
super.beforeEach()
- if (sqlContext.tableNames().contains("src")) {
- sqlContext.dropTempTable("src")
+ if (spark.wrapped.tableNames().contains("src")) {
+ spark.catalog.dropTempTable("src")
}
Seq((1, "")).toDF("key", "value").registerTempTable("src")
Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes")
@@ -39,8 +39,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
override protected def afterEach(): Unit = {
try {
- sqlContext.dropTempTable("src")
- sqlContext.dropTempTable("dupAttributes")
+ spark.catalog.dropTempTable("src")
+ spark.catalog.dropTempTable("dupAttributes")
} finally {
super.afterEach()
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index a717a9978e..bfe559f0b2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -555,7 +555,7 @@ object SparkSQLConfTest extends Logging {
object SPARK_9757 extends QueryTest {
import org.apache.spark.sql.functions._
- protected var sqlContext: SQLContext = _
+ protected var spark: SparkSession = _
def main(args: Array[String]): Unit = {
Utils.configTestLog4j("INFO")
@@ -567,7 +567,7 @@ object SPARK_9757 extends QueryTest {
.set("spark.ui.enabled", "false"))
val hiveContext = new TestHiveContext(sparkContext)
- sqlContext = hiveContext
+ spark = hiveContext.sparkSession
import hiveContext.implicits._
val dir = Utils.createTempDir()
@@ -602,7 +602,7 @@ object SPARK_9757 extends QueryTest {
object SPARK_11009 extends QueryTest {
import org.apache.spark.sql.functions._
- protected var sqlContext: SQLContext = _
+ protected var spark: SparkSession = _
def main(args: Array[String]): Unit = {
Utils.configTestLog4j("INFO")
@@ -613,10 +613,10 @@ object SPARK_11009 extends QueryTest {
.set("spark.sql.shuffle.partitions", "100"))
val hiveContext = new TestHiveContext(sparkContext)
- sqlContext = hiveContext
+ spark = hiveContext.sparkSession
try {
- val df = sqlContext.range(1 << 20)
+ val df = spark.range(1 << 20)
val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0")
@@ -633,7 +633,7 @@ object SPARK_14244 extends QueryTest {
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
- protected var sqlContext: SQLContext = _
+ protected var spark: SparkSession = _
def main(args: Array[String]): Unit = {
Utils.configTestLog4j("INFO")
@@ -644,13 +644,13 @@ object SPARK_14244 extends QueryTest {
.set("spark.sql.shuffle.partitions", "100"))
val hiveContext = new TestHiveContext(sparkContext)
- sqlContext = hiveContext
+ spark = hiveContext.sparkSession
import hiveContext.implicits._
try {
val window = Window.orderBy('id)
- val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist)
+ val df = spark.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist)
checkAnswer(df, Seq(Row(0.5D), Row(1.0D)))
} finally {
sparkContext.stop()
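In these standalone driver objects the stashed handle changes type, not just name: each object now keeps a SparkSession taken from the TestHiveContext it constructs. A sketch of that shape, with a plain SQLContext standing in for TestHiveContext (which lives in the hive test jar) and illustrative names throughout:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession}

object DriverObjectSketch {
  // The field the rest of the object works against is now a SparkSession
  // (the suites above declare it protected, mirroring QueryTest).
  var spark: SparkSession = _

  def main(args: Array[String]): Unit = {
    val sparkContext = new SparkContext(
      new SparkConf().setMaster("local[1]").setAppName("driver-object-sketch"))
    // The suites build a TestHiveContext here; any SQLContext exposes the
    // underlying session the same way.
    val sqlContext = new SQLContext(sparkContext)
    spark = sqlContext.sparkSession

    try {
      // Downstream code then uses the session directly, e.g. spark.range
      // as in SPARK_11009 above.
      assert(spark.range(1 << 10).count() == (1 << 10))
    } finally {
      sparkContext.stop()
    }
  }
}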
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 52aba328de..82d3e49f92 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -251,7 +251,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
// this will pick up the output partitioning from the table definition
- sqlContext.table("source").write.insertInto("partitioned")
+ spark.table("source").write.insertInto("partitioned")
checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq)
}
@@ -272,7 +272,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql(
"""CREATE TABLE partitioned (id bigint, data string)
|PARTITIONED BY (part1 string, part2 string)""".stripMargin)
- sqlContext.table("source").write.insertInto("partitioned")
+ spark.table("source").write.insertInto("partitioned")
checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq)
}
@@ -283,7 +283,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data")
- val logical = InsertIntoTable(sqlContext.table("partitioned").logicalPlan,
+ val logical = InsertIntoTable(spark.table("partitioned").logicalPlan,
Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false)
assert(!logical.resolved, "Should not resolve: missing partition data")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 78c8f0043d..b2a80e70be 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -374,7 +374,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
val expectedPath =
sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable"))
val filesystemPath = new Path(expectedPath)
- val fs = filesystemPath.getFileSystem(sqlContext.sessionState.newHadoopConf())
+ val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf())
if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true)
// It is a managed table when we do not specify the location.
@@ -701,7 +701,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
// Manually create a metastore data source table.
CreateDataSourceTableUtils.createDataSourceTable(
- sparkSession = sqlContext.sparkSession,
+ sparkSession = spark,
tableIdent = TableIdentifier("wide_schema"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
@@ -891,18 +891,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
test("SPARK-8156:create table to specific database by 'use dbname' ") {
val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
- sqlContext.sql("""create database if not exists testdb8156""")
- sqlContext.sql("""use testdb8156""")
+ spark.sql("""create database if not exists testdb8156""")
+ spark.sql("""use testdb8156""")
df.write
.format("parquet")
.mode(SaveMode.Overwrite)
.saveAsTable("ttt3")
checkAnswer(
- sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"),
+ spark.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"),
Row("ttt3", false))
- sqlContext.sql("""use default""")
- sqlContext.sql("""drop database if exists testdb8156 CASCADE""")
+ spark.sql("""use default""")
+ spark.sql("""drop database if exists testdb8156 CASCADE""")
}
@@ -911,7 +911,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType)))
CreateDataSourceTableUtils.createDataSourceTable(
- sparkSession = sqlContext.sparkSession,
+ sparkSession = spark,
tableIdent = TableIdentifier("not_skip_hive_metadata"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
@@ -926,7 +926,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
.forall(column => CatalystSqlParser.parseDataType(column.dataType) == StringType))
CreateDataSourceTableUtils.createDataSourceTable(
- sparkSession = sqlContext.sparkSession,
+ sparkSession = spark,
tableIdent = TableIdentifier("skip_hive_metadata"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 850cb1eda5..6c9ce208db 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
- private lazy val df = sqlContext.range(10).coalesce(1).toDF()
+ private lazy val df = spark.range(10).coalesce(1).toDF()
private def checkTablePath(dbName: String, tableName: String): Unit = {
val metastoreTable = hiveContext.sharedState.externalCatalog.getTable(dbName, tableName)
@@ -36,12 +36,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(sqlContext.tableNames().contains("t"))
- checkAnswer(sqlContext.table("t"), df)
+ assert(spark.wrapped.tableNames().contains("t"))
+ checkAnswer(spark.table("t"), df)
}
- assert(sqlContext.tableNames(db).contains("t"))
- checkAnswer(sqlContext.table(s"$db.t"), df)
+ assert(spark.wrapped.tableNames(db).contains("t"))
+ checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
}
@@ -50,8 +50,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
test(s"saveAsTable() to non-default database - without USE - Overwrite") {
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
- assert(sqlContext.tableNames(db).contains("t"))
- checkAnswer(sqlContext.table(s"$db.t"), df)
+ assert(spark.wrapped.tableNames(db).contains("t"))
+ checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
}
@@ -64,9 +64,9 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val path = dir.getCanonicalPath
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
- sqlContext.createExternalTable("t", path, "parquet")
- assert(sqlContext.tableNames(db).contains("t"))
- checkAnswer(sqlContext.table("t"), df)
+ spark.catalog.createExternalTable("t", path, "parquet")
+ assert(spark.wrapped.tableNames(db).contains("t"))
+ checkAnswer(spark.table("t"), df)
sql(
s"""
@@ -76,8 +76,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
- assert(sqlContext.tableNames(db).contains("t1"))
- checkAnswer(sqlContext.table("t1"), df)
+ assert(spark.wrapped.tableNames(db).contains("t1"))
+ checkAnswer(spark.table("t1"), df)
}
}
}
@@ -88,10 +88,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempPath { dir =>
val path = dir.getCanonicalPath
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
- sqlContext.createExternalTable(s"$db.t", path, "parquet")
+ spark.catalog.createExternalTable(s"$db.t", path, "parquet")
- assert(sqlContext.tableNames(db).contains("t"))
- checkAnswer(sqlContext.table(s"$db.t"), df)
+ assert(spark.wrapped.tableNames(db).contains("t"))
+ checkAnswer(spark.table(s"$db.t"), df)
sql(
s"""
@@ -101,8 +101,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
- assert(sqlContext.tableNames(db).contains("t1"))
- checkAnswer(sqlContext.table(s"$db.t1"), df)
+ assert(spark.wrapped.tableNames(db).contains("t1"))
+ checkAnswer(spark.table(s"$db.t1"), df)
}
}
}
@@ -112,12 +112,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
df.write.mode(SaveMode.Append).saveAsTable("t")
- assert(sqlContext.tableNames().contains("t"))
- checkAnswer(sqlContext.table("t"), df.union(df))
+ assert(spark.wrapped.tableNames().contains("t"))
+ checkAnswer(spark.table("t"), df.union(df))
}
- assert(sqlContext.tableNames(db).contains("t"))
- checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
+ assert(spark.wrapped.tableNames(db).contains("t"))
+ checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
}
@@ -127,8 +127,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
- assert(sqlContext.tableNames(db).contains("t"))
- checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
+ assert(spark.wrapped.tableNames(db).contains("t"))
+ checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
}
@@ -138,10 +138,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(sqlContext.tableNames().contains("t"))
+ assert(spark.wrapped.tableNames().contains("t"))
df.write.insertInto(s"$db.t")
- checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
+ checkAnswer(spark.table(s"$db.t"), df.union(df))
}
}
}
@@ -150,13 +150,13 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
- assert(sqlContext.tableNames().contains("t"))
+ assert(spark.wrapped.tableNames().contains("t"))
}
- assert(sqlContext.tableNames(db).contains("t"))
+ assert(spark.wrapped.tableNames(db).contains("t"))
df.write.insertInto(s"$db.t")
- checkAnswer(sqlContext.table(s"$db.t"), df.union(df))
+ checkAnswer(spark.table(s"$db.t"), df.union(df))
}
}
@@ -164,10 +164,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
sql("CREATE TABLE t (key INT)")
- checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame)
+ checkAnswer(spark.table("t"), spark.emptyDataFrame)
}
- checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame)
+ checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame)
}
}
@@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
sql(s"CREATE TABLE t (key INT)")
- assert(sqlContext.tableNames().contains("t"))
- assert(!sqlContext.tableNames("default").contains("t"))
+ assert(spark.wrapped.tableNames().contains("t"))
+ assert(!spark.wrapped.tableNames("default").contains("t"))
}
- assert(!sqlContext.tableNames().contains("t"))
- assert(sqlContext.tableNames(db).contains("t"))
+ assert(!spark.wrapped.tableNames().contains("t"))
+ assert(spark.wrapped.tableNames(db).contains("t"))
activateDatabase(db) {
sql(s"DROP TABLE t")
- assert(!sqlContext.tableNames().contains("t"))
- assert(!sqlContext.tableNames("default").contains("t"))
+ assert(!spark.wrapped.tableNames().contains("t"))
+ assert(!spark.wrapped.tableNames("default").contains("t"))
}
- assert(!sqlContext.tableNames().contains("t"))
- assert(!sqlContext.tableNames(db).contains("t"))
+ assert(!spark.wrapped.tableNames().contains("t"))
+ assert(!spark.wrapped.tableNames(db).contains("t"))
}
}
@@ -208,18 +208,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|LOCATION '$path'
""".stripMargin)
- checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame)
+ checkAnswer(spark.table("t"), spark.emptyDataFrame)
df.write.parquet(s"$path/p=1")
sql("ALTER TABLE t ADD PARTITION (p=1)")
sql("REFRESH TABLE t")
- checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1)))
+ checkAnswer(spark.table("t"), df.withColumn("p", lit(1)))
df.write.parquet(s"$path/p=2")
sql("ALTER TABLE t ADD PARTITION (p=2)")
hiveContext.sessionState.refreshTable("t")
checkAnswer(
- sqlContext.table("t"),
+ spark.table("t"),
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2))))
}
}
@@ -240,18 +240,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|LOCATION '$path'
""".stripMargin)
- checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame)
+ checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame)
df.write.parquet(s"$path/p=1")
sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)")
sql(s"REFRESH TABLE $db.t")
- checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1)))
+ checkAnswer(spark.table(s"$db.t"), df.withColumn("p", lit(1)))
df.write.parquet(s"$path/p=2")
sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)")
hiveContext.sessionState.refreshTable(s"$db.t")
checkAnswer(
- sqlContext.table(s"$db.t"),
+ spark.table(s"$db.t"),
df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2))))
}
}
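MultiDatabaseSuite exercises the same renames against persistent tables: writes go through df.write.saveAsTable, reads come back via spark.table, createExternalTable moves to spark.catalog, and only tableNames keeps its SQLContext home (reached via spark.wrapped). A small sketch of the core round trip, assuming a local session with its default warehouse; it skips the temp-database scaffolding the suite uses:

import org.apache.spark.sql.{SaveMode, SparkSession}

object SaveAsTableSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("save-as-table-sketch")
      .master("local[1]")
      .getOrCreate()

    // The suite's lazy df, now built straight off the session.
    val df = spark.range(10).coalesce(1).toDF()

    // Persist as a table, then read it back through spark.table(...).
    df.write.mode(SaveMode.Overwrite).saveAsTable("t")
    assert(spark.table("t").count() == 10)

    // Listing helpers keep their SQLContext home in this patch, hence
    // spark.wrapped.tableNames(...) in the hunks above.

    spark.stop()
  }
}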
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
index af4dc1beec..3f6418cbe8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
@@ -70,12 +70,12 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi
|$ddl
""".stripMargin)
- sqlContext.sql(ddl)
+ spark.sql(ddl)
- val schema = sqlContext.table("parquet_compat").schema
- val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1)
- sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data")
- sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data")
+ val schema = spark.table("parquet_compat").schema
+ val rowRDD = spark.sparkContext.parallelize(rows).coalesce(1)
+ spark.createDataFrame(rowRDD, schema).registerTempTable("data")
+ spark.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data")
}
}
@@ -84,7 +84,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi
// Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings.
// Have to assume all BINARY values are strings here.
withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") {
- checkAnswer(sqlContext.read.parquet(path), rows)
+ checkAnswer(spark.read.parquet(path), rows)
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 0ba72b033f..0f416eb24d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -177,23 +177,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
(Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
data3.write.saveAsTable("agg3")
- val emptyDF = sqlContext.createDataFrame(
+ val emptyDF = spark.createDataFrame(
sparkContext.emptyRDD[Row],
StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
emptyDF.registerTempTable("emptyTable")
// Register UDAFs
- sqlContext.udf.register("mydoublesum", new MyDoubleSum)
- sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
- sqlContext.udf.register("longProductSum", new LongProductSum)
+ spark.udf.register("mydoublesum", new MyDoubleSum)
+ spark.udf.register("mydoubleavg", new MyDoubleAvg)
+ spark.udf.register("longProductSum", new LongProductSum)
}
override def afterAll(): Unit = {
try {
- sqlContext.sql("DROP TABLE IF EXISTS agg1")
- sqlContext.sql("DROP TABLE IF EXISTS agg2")
- sqlContext.sql("DROP TABLE IF EXISTS agg3")
- sqlContext.dropTempTable("emptyTable")
+ spark.sql("DROP TABLE IF EXISTS agg1")
+ spark.sql("DROP TABLE IF EXISTS agg2")
+ spark.sql("DROP TABLE IF EXISTS agg3")
+ spark.catalog.dropTempTable("emptyTable")
} finally {
super.afterAll()
}
@@ -210,7 +210,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("empty table") {
// If there is no GROUP BY clause and the table is empty, we will generate a single row.
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| AVG(value),
@@ -227,7 +227,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, 0, 0, 0, null, null, null, null, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| AVG(value),
@@ -246,7 +246,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// If there is a GROUP BY clause and the table is empty, there is no output.
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| AVG(value),
@@ -266,7 +266,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("null literal") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| AVG(null),
@@ -282,7 +282,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("only do grouping") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT key
|FROM agg1
@@ -291,7 +291,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT DISTINCT value1, key
|FROM agg2
@@ -308,7 +308,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT value1, key
|FROM agg2
@@ -326,7 +326,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT DISTINCT key
|FROM agg3
@@ -341,7 +341,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(Seq[Integer](3)) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT value1, key
|FROM agg3
@@ -363,7 +363,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("case in-sensitive resolution") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT avg(value), kEY - 100
|FROM agg1
@@ -372,7 +372,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT sum(distinct value1), kEY - 100, count(distinct value1)
|FROM agg2
@@ -381,7 +381,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT valUe * key - 100
|FROM agg1
@@ -397,7 +397,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("test average no key in output") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT avg(value)
|FROM agg1
@@ -408,7 +408,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("test average") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT key, avg(value)
|FROM agg1
@@ -417,7 +417,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT key, mean(value)
|FROM agg1
@@ -426,7 +426,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT avg(value), key
|FROM agg1
@@ -435,7 +435,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT avg(value) + 1.5, key + 10
|FROM agg1
@@ -444,7 +444,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT avg(value) FROM agg1
""".stripMargin),
@@ -456,7 +456,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// deterministic.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| first_valUE(key),
@@ -472,7 +472,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| first_valUE(key),
@@ -491,7 +491,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("udaf") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| key,
@@ -511,7 +511,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("interpreted aggregate function") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT mydoublesum(value), key
|FROM agg1
@@ -520,14 +520,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT mydoublesum(value) FROM agg1
""".stripMargin),
Row(89.0) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT mydoublesum(null)
""".stripMargin),
@@ -536,7 +536,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("interpreted and expression-based aggregation functions") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT mydoublesum(value), key, avg(value)
|FROM agg1
@@ -548,7 +548,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(30.0, null, 10.0) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| mydoublesum(value + 1.5 * key),
@@ -568,7 +568,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("single distinct column set") {
// DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| min(distinct value1),
@@ -581,7 +581,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(-60, 70.0, 101.0/9.0, 5.6, 100))
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| mydoubleavg(distinct value1),
@@ -600,7 +600,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| key,
@@ -618,7 +618,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| count(value1),
@@ -637,7 +637,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("single distinct multiple columns set") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| key,
@@ -653,7 +653,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("multiple distinct multiple columns sets") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| key,
@@ -681,7 +681,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("test count") {
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| count(value2),
@@ -704,7 +704,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(0, null, 1, 1, null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT
| count(value2),
@@ -786,28 +786,28 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
covar_tab.registerTempTable("covar_tab")
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT corr(b, c) FROM covar_tab WHERE a < 1
""".stripMargin),
Row(null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT corr(b, c) FROM covar_tab WHERE a < 3
""".stripMargin),
Row(null) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT corr(b, c) FROM covar_tab WHERE a = 3
""".stripMargin),
Row(Double.NaN) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a
""".stripMargin),
@@ -818,7 +818,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(5, Double.NaN) ::
Row(6, Double.NaN) :: Nil)
- val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
+ val corr7 = spark.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
}
@@ -852,7 +852,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
}
test("no aggregation function (SPARK-11486)") {
- val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
+ val df = spark.range(20).selectExpr("id", "repeat(id, 1) as s")
.groupBy("s").count()
.groupBy().count()
checkAnswer(df, Row(20) :: Nil)
@@ -906,8 +906,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
}
// Create a DF for the schema with random data.
- val rdd = sqlContext.sparkContext.parallelize(data, 1)
- val df = sqlContext.createDataFrame(rdd, schema)
+ val rdd = spark.sparkContext.parallelize(data, 1)
+ val df = spark.createDataFrame(rdd, schema)
val allColumns = df.schema.fields.map(f => col(f.name))
val expectedAnswer =
@@ -924,7 +924,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
test("udaf without specifying inputSchema") {
withTempTable("noInputSchemaUDAF") {
- sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)
+ spark.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)
val data =
Row(1, Seq(Row(1), Row(2), Row(3))) ::
@@ -935,13 +935,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
StructField("key", IntegerType) ::
StructField("myArray",
ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil)
- sqlContext.createDataFrame(
+ spark.createDataFrame(
sparkContext.parallelize(data, 2),
schema)
.registerTempTable("noInputSchemaUDAF")
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT key, noInputSchema(myArray)
|FROM noInputSchemaUDAF
@@ -950,7 +950,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(1, 21) :: Row(2, -10) :: Nil)
checkAnswer(
- sqlContext.sql(
+ spark.sql(
"""
|SELECT noInputSchema(myArray)
|FROM noInputSchemaUDAF
@@ -976,7 +976,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
// todo: remove it?
- val newActual = Dataset.ofRows(sqlContext.sparkSession, actual.logicalPlan)
+ val newActual = Dataset.ofRows(spark, actual.logicalPlan)
QueryTest.checkAnswer(newActual, expectedAnswer) match {
case Some(errorMessage) =>
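AggregationQuerySuite touches every session-level entry point at once: createDataFrame, udf.register, and sql all move from sqlContext to spark. A compact sketch of those three calls, assuming a local session; it registers a plain Scala UDF where the suite registers UDAFs such as MyDoubleSum, purely to keep the example short:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SessionEntryPointsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("session-entry-points-sketch")
      .master("local[1]")
      .getOrCreate()

    // spark.createDataFrame(rdd, schema) replaces sqlContext.createDataFrame.
    val schema = StructType(
      StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)
    val rdd = spark.sparkContext.parallelize(Seq(Row("a", 1), Row("b", 2)))
    spark.createDataFrame(rdd, schema).registerTempTable("kv")

    // Function registration is session-level too: spark.udf.register(...).
    spark.udf.register("plusOne", (x: Int) => x + 1)

    // Ad-hoc SQL goes through spark.sql(...).
    assert(spark.sql("SELECT key, plusOne(value) AS v FROM kv").count() == 2)

    spark.stop()
  }
}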
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 0f23949d98..6dcc404636 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -36,7 +36,7 @@ class HiveDDLSuite
override def afterEach(): Unit = {
try {
// drop all databases, tables and functions after each test
- sqlContext.sessionState.catalog.reset()
+ spark.sessionState.catalog.reset()
} finally {
super.afterEach()
}
@@ -212,7 +212,7 @@ class HiveDDLSuite
test("drop views") {
withTable("tab1") {
val tabName = "tab1"
- sqlContext.range(10).write.saveAsTable("tab1")
+ spark.range(10).write.saveAsTable("tab1")
withView("view1") {
val viewName = "view1"
@@ -233,7 +233,7 @@ class HiveDDLSuite
test("alter views - rename") {
val tabName = "tab1"
withTable(tabName) {
- sqlContext.range(10).write.saveAsTable(tabName)
+ spark.range(10).write.saveAsTable(tabName)
val oldViewName = "view1"
val newViewName = "view2"
withView(oldViewName, newViewName) {
@@ -252,7 +252,7 @@ class HiveDDLSuite
test("alter views - set/unset tblproperties") {
val tabName = "tab1"
withTable(tabName) {
- sqlContext.range(10).write.saveAsTable(tabName)
+ spark.range(10).write.saveAsTable(tabName)
val viewName = "view1"
withView(viewName) {
val catalog = hiveContext.sessionState.catalog
@@ -290,7 +290,7 @@ class HiveDDLSuite
test("alter views and alter table - misuse") {
val tabName = "tab1"
withTable(tabName) {
- sqlContext.range(10).write.saveAsTable(tabName)
+ spark.range(10).write.saveAsTable(tabName)
val oldViewName = "view1"
val newViewName = "view2"
withView(oldViewName, newViewName) {
@@ -354,7 +354,7 @@ class HiveDDLSuite
test("drop view using drop table") {
withTable("tab1") {
- sqlContext.range(10).write.saveAsTable("tab1")
+ spark.range(10).write.saveAsTable("tab1")
withView("view1") {
sql("CREATE VIEW view1 AS SELECT * FROM tab1")
val message = intercept[AnalysisException] {
@@ -383,7 +383,7 @@ class HiveDDLSuite
}
private def createDatabaseWithLocation(tmpDir: File, dirExists: Boolean): Unit = {
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val dbName = "db1"
val tabName = "tab1"
val fs = new Path(tmpDir.toString).getFileSystem(hiveContext.sessionState.newHadoopConf())
@@ -442,7 +442,7 @@ class HiveDDLSuite
assert(!fs.exists(dbPath))
sql(s"CREATE DATABASE $dbName")
- val catalog = sqlContext.sessionState.catalog
+ val catalog = spark.sessionState.catalog
val expectedDBLocation = "file:" + appendTrailingSlash(dbPath.toString) + s"$dbName.db"
val db1 = catalog.getDatabaseMetadata(dbName)
assert(db1 == CatalogDatabase(
@@ -518,7 +518,7 @@ class HiveDDLSuite
test("desc table for data source table") {
withTable("tab1") {
val tabName = "tab1"
- sqlContext.range(1).write.format("json").saveAsTable(tabName)
+ spark.range(1).write.format("json").saveAsTable(tabName)
assert(sql(s"DESC $tabName").collect().length == 1)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d07ac56586..dd4321d1b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -347,7 +347,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode")
}
- sqlContext.dropTempTable("testUDF")
+ spark.catalog.dropTempTable("testUDF")
}
test("Hive UDF in group by") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 1d597fe16d..2e4077df54 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -860,7 +860,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
test("Sorting columns are not in Generate") {
withTempTable("data") {
- sqlContext.range(1, 5)
+ spark.range(1, 5)
.select(array($"id", $"id" + 1).as("a"), $"id".as("b"), (lit(10) - $"id").as("c"))
.registerTempTable("data")
@@ -1081,7 +1081,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
// We don't support creating a temporary table while specifying a database
val message = intercept[AnalysisException] {
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE db.t
|USING parquet
@@ -1092,7 +1092,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}.getMessage
// If you use backticks to quote the name then it's OK.
- sqlContext.sql(
+ spark.sql(
s"""
|CREATE TEMPORARY TABLE `db.t`
|USING parquet
@@ -1100,12 +1100,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
| path '$path'
|)
""".stripMargin)
- checkAnswer(sqlContext.table("`db.t`"), df)
+ checkAnswer(spark.table("`db.t`"), df)
}
}
test("SPARK-10593 same column names in lateral view") {
- val df = sqlContext.sql(
+ val df = spark.sql(
"""
|select
|insideLayer2.json as a2
@@ -1120,7 +1120,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
ignore("SPARK-10310: " +
"script transformation using default input/output SerDe and record reader/writer") {
- sqlContext
+ spark
.range(5)
.selectExpr("id AS a", "id AS b")
.registerTempTable("test")
@@ -1138,7 +1138,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
ignore("SPARK-10310: script transformation using LazySimpleSerDe") {
- sqlContext
+ spark
.range(5)
.selectExpr("id AS a", "id AS b")
.registerTempTable("test")
@@ -1183,7 +1183,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("run sql directly on files") {
- val df = sqlContext.range(100).toDF()
+ val df = spark.range(100).toDF()
withTempPath(f => {
df.write.parquet(f.getCanonicalPath)
checkAnswer(sql(s"select id from parquet.`${f.getCanonicalPath}`"),
@@ -1325,14 +1325,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
Seq("3" -> "30").toDF("i", "j")
.write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453")
checkAnswer(
- sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"),
+ spark.read.table("tbl11453").select("i", "j").orderBy("i"),
Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil)
// make sure case sensitivity is correct.
Seq("4" -> "40").toDF("i", "j")
.write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453")
checkAnswer(
- sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"),
+ spark.read.table("tbl11453").select("i", "j").orderBy("i"),
Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil)
}
}
@@ -1370,7 +1370,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
test("multi-insert with lateral view") {
withTempTable("t1") {
- sqlContext.range(10)
+ spark.range(10)
.select(array($"id", $"id" + 1).as("arr"), $"id")
.registerTempTable("source")
withTable("dest1", "dest2") {
@@ -1388,10 +1388,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin)
checkAnswer(
- sqlContext.table("dest1"),
+ spark.table("dest1"),
sql("SELECT id FROM source WHERE id > 3"))
checkAnswer(
- sqlContext.table("dest2"),
+ spark.table("dest2"),
sql("SELECT col FROM source LATERAL VIEW EXPLODE(arr) exp AS col WHERE col > 3"))
}
}
@@ -1404,7 +1404,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
withTempPath { dir =>
withTempTable("t1", "t2") {
val path = dir.getCanonicalPath
- val ds = sqlContext.range(10)
+ val ds = spark.range(10)
ds.registerTempTable("t1")
sql(
@@ -1415,7 +1415,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin)
checkAnswer(
- sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
+ spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"),
Row(true)
)
@@ -1429,7 +1429,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
"shouldn always be used together with PATH data source option"
) {
withTempTable("t") {
- sqlContext.range(10).registerTempTable("t")
+ spark.range(10).registerTempTable("t")
val message = intercept[IllegalArgumentException] {
sql(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
index 72f9fba13d..f37037e3c7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
@@ -30,11 +30,11 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
override def beforeAll(): Unit = {
// Create a simple table with two columns: id and id1
- sqlContext.range(1, 10).selectExpr("id", "id id1").write.format("json").saveAsTable("jt")
+ spark.range(1, 10).selectExpr("id", "id id1").write.format("json").saveAsTable("jt")
}
override def afterAll(): Unit = {
- sqlContext.sql(s"DROP TABLE IF EXISTS jt")
+ spark.sql(s"DROP TABLE IF EXISTS jt")
}
test("nested views (interleaved with temporary views)") {
@@ -277,11 +277,11 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
withSQLConf(
SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") {
withTable("add_col") {
- sqlContext.range(10).write.saveAsTable("add_col")
+ spark.range(10).write.saveAsTable("add_col")
withView("v") {
sql("CREATE VIEW v AS SELECT * FROM add_col")
- sqlContext.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col")
- checkAnswer(sql("SELECT * FROM v"), sqlContext.range(10).toDF())
+ spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col")
+ checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF())
}
}
}
@@ -291,8 +291,8 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
// make sure the new flag can handle some complex cases like join and schema change.
withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") {
withTable("jt1", "jt2") {
- sqlContext.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1")
- sqlContext.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2")
+ spark.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1")
+ spark.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2")
sql("CREATE VIEW testView AS SELECT * FROM jt1 JOIN jt2 ON id1 == id2")
checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i)))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
index d0e7552c12..cbbeacf6ad 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
@@ -353,7 +353,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi
checkAnswer(actual, expected)
- sqlContext.dropTempTable("nums")
+ spark.catalog.dropTempTable("nums")
}
test("SPARK-7595: Window will cause resolve failed with self join") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
index b97da1ffdc..965680ff0d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -75,11 +75,11 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
(1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path)
checkAnswer(
- sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"),
+ spark.read.orc(path).where("not (a = 2) or not(b in ('1'))"),
(1 to 5).map(i => Row(i, (i % 2).toString)))
checkAnswer(
- sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"),
+ spark.read.orc(path).where("not (a = 2 and b in ('1'))"),
(1 to 5).map(i => Row(i, (i % 2).toString)))
}
}
@@ -94,7 +94,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
.orc(path)
// Check if this is compressed as ZLIB.
- val conf = sqlContext.sessionState.newHadoopConf()
+ val conf = spark.sessionState.newHadoopConf()
val fs = FileSystem.getLocal(conf)
val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".zlib.orc"))
assert(maybeOrcFile.isDefined)
@@ -102,7 +102,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
val orcReader = OrcFile.createReader(orcFilePath, OrcFile.readerOptions(conf))
assert(orcReader.getCompression == CompressionKind.ZLIB)
- val copyDf = sqlContext
+ val copyDf = spark
.read
.orc(path)
checkAnswer(df, copyDf)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index aa9c1189db..084546f99d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -66,7 +66,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withOrcFile(data) { file =>
checkAnswer(
- sqlContext.read.orc(file),
+ spark.read.orc(file),
data.toDF().collect())
}
}
@@ -170,7 +170,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
// Hive supports zlib, snappy and none for Hive 1.2.1.
test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") {
withTempPath { file =>
- sqlContext.range(0, 10).write
+ spark.range(0, 10).write
.option("orc.compress", "ZLIB")
.orc(file.getCanonicalPath)
val expectedCompressionKind =
@@ -179,7 +179,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
withTempPath { file =>
- sqlContext.range(0, 10).write
+ spark.range(0, 10).write
.option("orc.compress", "SNAPPY")
.orc(file.getCanonicalPath)
val expectedCompressionKind =
@@ -188,7 +188,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
withTempPath { file =>
- sqlContext.range(0, 10).write
+ spark.range(0, 10).write
.option("orc.compress", "NONE")
.orc(file.getCanonicalPath)
val expectedCompressionKind =
@@ -200,7 +200,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
// Following codec is not supported in Hive 1.2.1, ignore it now
ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") {
withTempPath { file =>
- sqlContext.range(0, 10).write
+ spark.range(0, 10).write
.option("orc.compress", "LZO")
.orc(file.getCanonicalPath)
val expectedCompressionKind =
@@ -301,12 +301,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withTempPath { dir =>
val path = dir.getCanonicalPath
- sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path)
- sqlContext.read.format("orc").load(path).schema("Acol")
+ spark.range(0, 10).select('id as "Acol").write.format("orc").save(path)
+ spark.read.format("orc").load(path).schema("Acol")
intercept[IllegalArgumentException] {
- sqlContext.read.format("orc").load(path).schema("acol")
+ spark.read.format("orc").load(path).schema("acol")
}
- checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"),
+ checkAnswer(spark.read.format("orc").load(path).select("acol").sort("acol"),
(0 until 10).map(Row(_)))
}
}
@@ -317,7 +317,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withTable("empty_orc") {
withTempTable("empty", "single") {
- sqlContext.sql(
+ spark.sql(
s"""CREATE TABLE empty_orc(key INT, value STRING)
|STORED AS ORC
|LOCATION '$path'
@@ -328,13 +328,13 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
// This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because
// Spark SQL ORC data source always avoids write empty ORC files.
- sqlContext.sql(
+ spark.sql(
s"""INSERT INTO TABLE empty_orc
|SELECT key, value FROM empty
""".stripMargin)
val errorMessage = intercept[AnalysisException] {
- sqlContext.read.orc(path)
+ spark.read.orc(path)
}.getMessage
assert(errorMessage.contains("Unable to infer schema for ORC"))
@@ -342,12 +342,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1)
singleRowDF.registerTempTable("single")
- sqlContext.sql(
+ spark.sql(
s"""INSERT INTO TABLE empty_orc
|SELECT key, value FROM single
""".stripMargin)
- val df = sqlContext.read.orc(path)
+ val df = spark.read.orc(path)
assert(df.schema === singleRowDF.schema.asNullable)
checkAnswer(df, singleRowDF)
}
@@ -373,7 +373,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
// It needs to repartition data so that we can have several ORC files
// in order to skip stripes in ORC.
createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path)
- val df = sqlContext.read.orc(path)
+ val df = spark.read.orc(path)
def checkPredicate(pred: Column, answer: Seq[Row]): Unit = {
val sourceDf = stripSparkFilter(df.where(pred))
@@ -415,7 +415,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withTable("dummy_orc") {
withTempTable("single") {
- sqlContext.sql(
+ spark.sql(
s"""CREATE TABLE dummy_orc(key INT, value STRING)
|STORED AS ORC
|LOCATION '$path'
@@ -424,12 +424,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1)
singleRowDF.registerTempTable("single")
- sqlContext.sql(
+ spark.sql(
s"""INSERT INTO TABLE dummy_orc
|SELECT key, value FROM single
""".stripMargin)
- val df = sqlContext.sql("SELECT * FROM dummy_orc WHERE key=0")
+ val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0")
checkAnswer(df, singleRowDF)
val queryExecution = df.queryExecution
@@ -448,7 +448,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
val data = (0 until 10).map(i => Tuple1(Array(i)))
withOrcFile(data) { file =>
- val actual = sqlContext
+ val actual = spark
.read
.orc(file)
.where("_1 is not null")
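The ORC suites reduce to the same reader/writer renames: spark.range(...).write on the way in and spark.read.orc(...) on the way out. A round-trip sketch under the assumption that Hive support is on the classpath (the ORC data source shipped in the hive module at the time); the temp-path handling is illustrative:

import java.nio.file.Files

import org.apache.spark.sql.SparkSession

object OrcRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // enableHiveSupport because ORC lived in the hive module in this era.
    val spark = SparkSession.builder()
      .appName("orc-round-trip-sketch")
      .master("local[1]")
      .enableHiveSupport()
      .getOrCreate()

    val path = Files.createTempDirectory("orc-sketch").resolve("data").toString

    // Writer side: spark.range(...).write replaces sqlContext.range(...).write.
    spark.range(0, 10).write.option("orc.compress", "ZLIB").orc(path)

    // Reader side: spark.read.orc replaces sqlContext.read.orc.
    assert(spark.read.orc(path).count() == 10)

    spark.stop()
  }
}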
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index 637c10611a..aba60da33f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -49,7 +49,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
- withOrcFile(data)(path => f(sqlContext.read.orc(path)))
+ withOrcFile(data)(path => f(spark.read.orc(path)))
}
/**
@@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
- sqlContext.registerDataFrameAsTable(df, tableName)
+ spark.wrapped.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 1c1f6d910d..6e93bbde26 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -634,7 +634,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
""".stripMargin)
checkAnswer(
- sqlContext.read.parquet(path),
+ spark.read.parquet(path),
Row("1st", "2nd", Seq(Row("val_a", "val_b"))))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index a3e7737a7c..8bf6f224a4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -105,7 +105,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
}
// Read the bucket file into a dataframe, so that it's easier to test.
- val readBack = sqlContext.read.format(source)
+ val readBack = spark.read.format(source)
.load(bucketFile.getAbsolutePath)
.select(columns: _*)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index 08e83b7f69..f9387fae4a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -34,7 +34,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
// Here we coalesce partition number to 1 to ensure that only a single task is issued. This
// prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
// directory while committing/aborting the job. See SPARK-8513 for more details.
- val df = sqlContext.range(0, 10).coalesce(1)
+ val df = spark.range(0, 10).coalesce(1)
intercept[SparkException] {
df.write.format(dataSourceName).save(file.getCanonicalPath)
}
@@ -49,7 +49,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
withTempPath { file =>
// fail the job in the middle of writing
val divideByZero = udf((x: Int) => { x / (x - 1)})
- val df = sqlContext.range(0, 10).coalesce(1).select(divideByZero(col("id")))
+ val df = spark.range(0, 10).coalesce(1).select(divideByZero(col("id")))
SimpleTextRelation.callbackCalled = false
intercept[SparkException] {
@@ -66,7 +66,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
SimpleTextRelation.failCommitter = false
withTempPath { file =>
// fail the job in the middle of writing
- val df = sqlContext.range(0, 10).coalesce(1).select(col("id").mod(2).as("key"), col("id"))
+ val df = spark.range(0, 10).coalesce(1).select(col("id").mod(2).as("key"), col("id"))
SimpleTextRelation.callbackCalled = false
SimpleTextRelation.failWriter = true
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index 20c5f72ff1..f4d63334b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton {
- import sqlContext.implicits._
+ import spark.implicits._
val dataSourceName: String
@@ -143,8 +143,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.add("index", IntegerType, nullable = false)
.add("col", dataType, nullable = true)
val rdd =
- sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator())))
- val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
+ spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator())))
+ val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
df.write
.mode("overwrite")
@@ -153,7 +153,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.options(extraOptions)
.save(path)
- val loadedDF = sqlContext
+ val loadedDF = spark
.read
.format(dataSourceName)
.option("dataSchema", df.schema.json)
@@ -174,7 +174,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath)
checkAnswer(
- sqlContext.read.format(dataSourceName)
+ spark.read.format(dataSourceName)
.option("path", file.getCanonicalPath)
.option("dataSchema", dataSchema.json)
.load(),
@@ -188,7 +188,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath)
checkAnswer(
- sqlContext.read.format(dataSourceName)
+ spark.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath).orderBy("a"),
testDF.union(testDF).orderBy("a").collect())
@@ -208,7 +208,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath)
val path = new Path(file.getCanonicalPath)
- val fs = path.getFileSystem(sqlContext.sessionState.newHadoopConf())
+ val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
assert(fs.listStatus(path).isEmpty)
}
}
@@ -222,7 +222,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.save(file.getCanonicalPath)
checkQueries(
- sqlContext.read.format(dataSourceName)
+ spark.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath))
}
@@ -243,7 +243,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.save(file.getCanonicalPath)
checkAnswer(
- sqlContext.read.format(dataSourceName)
+ spark.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
partitionedTestDF.collect())
@@ -265,7 +265,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.save(file.getCanonicalPath)
checkAnswer(
- sqlContext.read.format(dataSourceName)
+ spark.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
partitionedTestDF.union(partitionedTestDF).collect())
@@ -287,7 +287,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.save(file.getCanonicalPath)
checkAnswer(
- sqlContext.read.format(dataSourceName)
+ spark.read.format(dataSourceName)
.option("dataSchema", dataSchema.json)
.load(file.getCanonicalPath),
partitionedTestDF.collect())
@@ -323,7 +323,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")
withTable("t") {
- checkAnswer(sqlContext.table("t"), testDF.collect())
+ checkAnswer(spark.table("t"), testDF.collect())
}
}
@@ -332,7 +332,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t")
withTable("t") {
- checkAnswer(sqlContext.table("t"), testDF.union(testDF).orderBy("a").collect())
+ checkAnswer(spark.table("t"), testDF.union(testDF).orderBy("a").collect())
}
}
@@ -351,7 +351,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
withTempTable("t") {
testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t")
- assert(sqlContext.table("t").collect().isEmpty)
+ assert(spark.table("t").collect().isEmpty)
}
}
@@ -362,18 +362,18 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")
withTable("t") {
- checkQueries(sqlContext.table("t"))
+ checkQueries(spark.table("t"))
}
}
test("saveAsTable()/load() - partitioned table - boolean type") {
- sqlContext.range(2)
+ spark.range(2)
.select('id, ('id % 2 === 0).as("b"))
.write.partitionBy("b").saveAsTable("t")
withTable("t") {
checkAnswer(
- sqlContext.table("t").sort('id),
+ spark.table("t").sort('id),
Row(0, true) :: Row(1, false) :: Nil
)
}
@@ -395,7 +395,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")
withTable("t") {
- checkAnswer(sqlContext.table("t"), partitionedTestDF.collect())
+ checkAnswer(spark.table("t"), partitionedTestDF.collect())
}
}
@@ -415,7 +415,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")
withTable("t") {
- checkAnswer(sqlContext.table("t"), partitionedTestDF.union(partitionedTestDF).collect())
+ checkAnswer(spark.table("t"), partitionedTestDF.union(partitionedTestDF).collect())
}
}
@@ -435,7 +435,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")
withTable("t") {
- checkAnswer(sqlContext.table("t"), partitionedTestDF.collect())
+ checkAnswer(spark.table("t"), partitionedTestDF.collect())
}
}
@@ -484,7 +484,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.partitionBy("p1", "p2")
.saveAsTable("t")
- assert(sqlContext.table("t").collect().isEmpty)
+ assert(spark.table("t").collect().isEmpty)
}
}
@@ -516,7 +516,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
// Inferring the schema should throw an error, as there are no files to infer it from
val e = intercept[Exception] {
- sqlContext.read.format(dataSourceName).load(dir.getCanonicalPath)
+ spark.read.format(dataSourceName).load(dir.getCanonicalPath)
}
e match {
@@ -533,7 +533,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
/** Test whether data read from the given path matches the expected answer */
def testWithPath(path: File, expectedAnswer: Seq[Row]): Unit = {
- val df = sqlContext.read
+ val df = spark.read
.format(dataSourceName)
.schema(dataInDir.schema) // avoid schema inference for any format
.load(path.getCanonicalPath)
@@ -618,7 +618,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
/** Check whether data read from the given path matches the expected answer */
def check(path: String, expectedDf: DataFrame): Unit = {
- val df = sqlContext.read
+ val df = spark.read
.format(dataSourceName)
.schema(schema) // avoid schema inference for any format; expected to be the same format
.load(path)
@@ -654,7 +654,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
basePath: Option[String] = None
): Unit = {
try {
- val reader = sqlContext.read
+ val reader = spark.read
basePath.foreach(reader.option("basePath", _))
val testDf = reader
.format(dataSourceName)
@@ -739,7 +739,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
val realData = input.collect()
- checkAnswer(sqlContext.table("t"), realData ++ realData)
+ checkAnswer(spark.table("t"), realData ++ realData)
}
}
}
@@ -754,7 +754,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.saveAsTable("t")
withTable("t") {
- checkAnswer(sqlContext.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect())
+ checkAnswer(spark.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect())
}
}
@@ -766,7 +766,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
test("SPARK-8406: Avoids name collision while writing files") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- sqlContext
+ spark
.range(10000)
.repartition(250)
.write
@@ -775,7 +775,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.save(path)
assertResult(10000) {
- sqlContext
+ spark
.read
.format(dataSourceName)
.option("dataSchema", StructType(StructField("id", LongType) :: Nil).json)
@@ -794,7 +794,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
classOf[AlwaysFailParquetOutputCommitter].getName
)
- val df = sqlContext.range(1, 10).toDF("i")
+ val df = spark.range(1, 10).toDF("i")
withTempPath { dir =>
df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
// Because the data already exists,
@@ -802,7 +802,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
// with file format and AlwaysFailOutputCommitter will not be used.
df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
checkAnswer(
- sqlContext.read
+ spark.read
.format(dataSourceName)
.option("dataSchema", df.schema.json)
.options(extraOptions)
@@ -850,12 +850,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
)
withTempPath { dir =>
val path = "file://" + dir.getCanonicalPath
- val df1 = sqlContext.range(4)
+ val df1 = spark.range(4)
df1.coalesce(1).write.mode("overwrite").options(options).format(dataSourceName).save(path)
df1.coalesce(1).write.mode("append").options(options).format(dataSourceName).save(path)
def checkLocality(): Unit = {
- val df2 = sqlContext.read
+ val df2 = spark.read
.format(dataSourceName)
.option("dataSchema", df1.schema.json)
.options(options)
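Note: the hunks above are a mechanical substitution of the shared `sqlContext` test fixture with the suite's `spark` (SparkSession) field; the reader/writer and catalog calls are otherwise unchanged. As a minimal, self-contained sketch of the resulting style (the object name, local master, and temp path below are invented for illustration and are not part of the patch):

import org.apache.spark.sql.{SaveMode, SparkSession}

object SessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical standalone session; the real suites obtain `spark` from
    // their shared test harness instead of building one here.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("session-migration-sketch")
      .getOrCreate()

    val df = spark.range(4).toDF("id")
    val path = "/tmp/session-migration-sketch"  // hypothetical output path

    // Write through the DataFrameWriter, then read back via the session,
    // mirroring the sqlContext -> spark substitution above.
    df.write.mode(SaveMode.Overwrite).parquet(path)
    assert(spark.read.parquet(path).count() == 4)

    spark.stop()
  }
}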
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index 1d104889fe..4b4852c1d7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -126,18 +126,18 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
test("SPARK-8604: Parquet data source should write summary file while doing appending") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df = sqlContext.range(0, 5).toDF()
+ val df = spark.range(0, 5).toDF()
df.write.mode(SaveMode.Overwrite).parquet(path)
val summaryPath = new Path(path, "_metadata")
val commonSummaryPath = new Path(path, "_common_metadata")
- val fs = summaryPath.getFileSystem(sqlContext.sessionState.newHadoopConf())
+ val fs = summaryPath.getFileSystem(spark.sessionState.newHadoopConf())
fs.delete(summaryPath, true)
fs.delete(commonSummaryPath, true)
df.write.mode(SaveMode.Append).parquet(path)
- checkAnswer(sqlContext.read.parquet(path), df.union(df))
+ checkAnswer(spark.read.parquet(path), df.union(df))
assert(fs.exists(summaryPath))
assert(fs.exists(commonSummaryPath))
@@ -148,8 +148,8 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
withTempPath { dir =>
val path = dir.getCanonicalPath
- sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path)
- val df = sqlContext.read.parquet(path).filter('a === 0).select('b)
+ spark.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path)
+ val df = spark.read.parquet(path).filter('a === 0).select('b)
val physicalPlan = df.queryExecution.sparkPlan
assert(physicalPlan.collect { case p: execution.ProjectExec => p }.length === 1)
@@ -170,7 +170,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
// The schema consists of the leading columns of the first part-file
// in the lexicographic order.
- assert(sqlContext.read.parquet(dir.getCanonicalPath).schema.map(_.name)
+ assert(spark.read.parquet(dir.getCanonicalPath).schema.map(_.name)
=== Seq("a", "b", "c", "d", "part"))
}
}
@@ -188,8 +188,8 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
Row(5, 127.toByte), Row(6, -44.toByte), Row(7, 23.toByte), Row(8, -95.toByte),
Row(9, 127.toByte), Row(10, 13.toByte))
- val rdd = sqlContext.sparkContext.parallelize(data)
- val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
+ val rdd = spark.sparkContext.parallelize(data)
+ val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
df.write
.mode("overwrite")
@@ -197,7 +197,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
.option("dataSchema", df.schema.json)
.save(path)
- val loadedDF = sqlContext
+ val loadedDF = spark
.read
.format(dataSourceName)
.option("dataSchema", df.schema.json)
@@ -221,7 +221,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
val compressedFiles = new File(path).listFiles()
assert(compressedFiles.exists(_.getName.endsWith(".gz.parquet")))
- val copyDf = sqlContext
+ val copyDf = spark
.read
.parquet(path)
checkAnswer(df, copyDf)
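The `spark.createDataFrame(rdd, schema)` hunk above is the same substitution applied to RDD-based DataFrame construction. A hedged, self-contained variant (the schema, rows, and object name are made up for illustration; only the `spark.sparkContext` and `spark.createDataFrame` calls mirror the diff):

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object CreateDataFrameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("create-dataframe-sketch")
      .getOrCreate()

    // Invented schema and rows, standing in for the suite's test data.
    val schema = StructType(StructField("index", IntegerType, nullable = false) :: Nil)
    val rdd = spark.sparkContext.parallelize(Seq(Row(3), Row(1), Row(2)))
    val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)

    assert(df.collect().map(_.getInt(0)).toSeq == Seq(1, 2, 3))
    spark.stop()
  }
}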
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index 9ad0887609..fa64c7dcfa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -69,7 +69,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
test("test hadoop conf option propagation") {
withTempPath { file =>
// Test write side
- val df = sqlContext.range(10).selectExpr("cast(id as string)")
+ val df = spark.range(10).selectExpr("cast(id as string)")
df.write
.option("some-random-write-option", "hahah-WRITE")
.option("some-null-value-option", null) // test null robustness
@@ -78,7 +78,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
assert(SimpleTextRelation.lastHadoopConf.get.get("some-random-write-option") == "hahah-WRITE")
// Test read side
- val df1 = sqlContext.read
+ val df1 = spark.read
.option("some-random-read-option", "hahah-READ")
.option("some-null-value-option", null) // test null robustness
.option("dataSchema", df.schema.json)