author     Reynold Xin <rxin@databricks.com>    2015-06-05 13:57:21 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-06-05 13:57:21 -0700
commit     6ebe419f335fcfb66dd3da74baf35eb5b2fc061d (patch)
tree       ba0c4e82bdb819bc1229942fa8a26211622cae34 /sql
parent     356a4a9b93a1eeedb910c6bccc0abadf59e4877f (diff)
[SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ cont'd.
Fixed the following packages:
sql.columnar
sql.jdbc
sql.json
sql.parquet

Author: Reynold Xin <rxin@databricks.com>

Closes #6667 from rxin/testsqlcontext_wildcard and squashes the following commits:

134a776 [Reynold Xin] Fixed compilation break.
6da7b69 [Reynold Xin] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ cont'd.
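The recurring pattern in this patch replaces the two wildcard imports with an explicit, lazily initialized reference to the test context plus targeted imports, so call sites become `ctx.sql(...)`, `ctx.table(...)`, and so on. Below is a minimal sketch of that pattern; only `TestSQLContext`, `import ctx.implicits._`, and `import ctx.sql` come from the patch, while the suite and table names are illustrative and it assumes Spark 1.4's sql test classpath:

    import org.scalatest.FunSuite

    // Hypothetical suite showing the shape used in InMemoryColumnarQuerySuite,
    // JDBCSuite, JsonSuite, and the other touched suites.
    class ExplicitContextSuite extends FunSuite {

      // Before: `import org.apache.spark.sql.test.TestSQLContext._` pulled sql,
      // table, cacheTable, conf, etc. into scope with no hint of their origin.
      // After: one lazy handle on the context, plus only the imports needed.
      private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
      import ctx.implicits._   // keeps .toDF() conversions working
      import ctx.sql           // keeps bare sql("...") calls readable

      test("explicit context calls") {
        ctx.sparkContext.parallelize(1 to 3)
          .map(i => (i, i.toString))
          .toDF("id", "value")
          .registerTempTable("example")   // hypothetical temp table

        // Methods that used to be bare (cacheTable, table) are now qualified.
        ctx.cacheTable("example")
        assert(ctx.table("example").count() === 3L)
        assert(sql("SELECT id FROM example WHERE id = 2").collect().length === 1)
      }
    }

Making the context reference lazy also avoids touching the shared TestSQLContext at class-construction time, which is why several suites in the diff convert eagerly initialized vals (e.g. originalColumnBatchSize, originalConf) to lazy vals as well.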
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala   | 40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala    | 28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala                         | 45
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala                    | 75
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala                         | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala                      | 62
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala             |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala                 | 40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala | 27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala              | 24
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala             |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala                    |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala                      | 14
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala                 |  5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala                       |  8
15 files changed, 234 insertions, 245 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 055453e688..fa3b8144c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,8 +21,6 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -31,8 +29,12 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.{logicalPlanToSparkQuery, sql}
+
test("simple columnar query") {
- val plan = executePlan(testData.logicalPlan).executedPlan
+ val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -40,16 +42,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
- sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
.toDF().registerTempTable("sizeTst")
- cacheTable("sizeTst")
+ ctx.cacheTable("sizeTst")
assert(
- table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
- conf.autoBroadcastJoinThreshold)
+ ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
+ ctx.conf.autoBroadcastJoinThreshold)
}
test("projection") {
- val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan
+ val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().map {
@@ -58,7 +60,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
- val plan = executePlan(testData.logicalPlan).executedPlan
+ val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -70,7 +72,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
- cacheTable("repeatedData")
+ ctx.cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
@@ -82,7 +84,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
- cacheTable("nullableRepeatedData")
+ ctx.cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
@@ -94,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT time FROM timestamps"),
timestamps.collect().toSeq.map(Row.fromTuple))
- cacheTable("timestamps")
+ ctx.cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
@@ -106,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
- cacheTable("withEmptyParts")
+ ctx.cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
@@ -155,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Create a RDD for the schema
val rdd =
- sparkContext.parallelize((1 to 100), 10).map { i =>
+ ctx.sparkContext.parallelize((1 to 100), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
@@ -175,18 +177,18 @@ class InMemoryColumnarQuerySuite extends QueryTest {
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
- createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
+ ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
sql("cache table InMemoryCache_different_data_types")
// Make sure the table is indeed cached.
- val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan
+ val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan
assert(
- isCached("InMemoryCache_different_data_types"),
+ ctx.isCached("InMemoryCache_different_data_types"),
"InMemoryCache_different_data_types should be cached.")
// Issue a query and check the results.
checkAnswer(
sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
- table("InMemoryCache_different_data_types").collect())
- dropTempTable("InMemoryCache_different_data_types")
+ ctx.table("InMemoryCache_different_data_types").collect())
+ ctx.dropTempTable("InMemoryCache_different_data_types")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index cda1b0992e..6545c6b314 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -21,40 +21,42 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
- val originalColumnBatchSize = conf.columnBatchSize
- val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
+
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+
+ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
+ private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
- setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
- val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
+ val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
}, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
- setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
// Enable in-memory table scan accumulators
- setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
+ ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
- setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
- setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}
before {
- cacheTable("pruningData")
+ ctx.cacheTable("pruningData")
}
after {
- uncacheTable("pruningData")
+ ctx.uncacheTable("pruningData")
}
// Comparisons
@@ -108,7 +110,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
expectedQueryResult: => Seq[Int]): Unit = {
test(query) {
- val df = sql(query)
+ val df = ctx.sql(query)
val queryExecution = df.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index e20c66cb2f..7931854db2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -21,13 +21,11 @@ import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar, Properties}
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.test._
-import org.apache.spark.sql.types._
import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
-import TestSQLContext._
-import TestSQLContext.implicits._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
@@ -37,12 +35,16 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
val testH2Dialect = new JdbcDialect {
- def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
+ override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
Some(StringType)
}
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.sql
+
before {
Class.forName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
@@ -253,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("Basic API") {
- assert(TestSQLContext.read.jdbc(
+ assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}
test("Basic API with FetchSize") {
val properties = new Properties
properties.setProperty("fetchSize", "2")
- assert(TestSQLContext.read.jdbc(
+ assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}
test("Partitioning via JDBCPartitioningInfo API") {
assert(
- TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
+ ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
.collect().length === 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
- assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
+ assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
.collect().length === 3)
}
@@ -328,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("test DATE types") {
- val rows = TestSQLContext.read.jdbc(
+ val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
@@ -338,9 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("test DATE types in cache") {
- val rows =
- TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
+ ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().registerTempTable("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
@@ -348,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("test types for null value") {
- val rows = TestSQLContext.read.jdbc(
+ val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
@@ -395,10 +396,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
- val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
- assert(df.schema.filter(
- _.dataType != org.apache.spark.sql.types.StringType
- ).isEmpty)
+ val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
assert(rows(0).get(1).isInstanceOf[String])
@@ -419,7 +418,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
test("Aggregated dialects") {
val agg = new AggregatedDialect(List(new JdbcDialect {
- def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
+ override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
if (sqlType % 2 == 0) {
@@ -430,8 +429,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
- assert(agg.getCatalystType(0, "", 1, null) == Some(LongType))
- assert(agg.getCatalystType(1, "", 1, null) == Some(StringType))
+ assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
+ assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 2de8c1a609..d949ef4226 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -24,7 +24,6 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SaveMode, Row}
-import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
@@ -37,6 +36,10 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
+ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.implicits._
+ import ctx.sql
+
before {
Class.forName("org.h2.Driver")
conn = DriverManager.getConnection(url)
@@ -54,14 +57,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
- TestSQLContext.sql(
+ ctx.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
- TestSQLContext.sql(
+ ctx.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
@@ -74,66 +77,64 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
conn1.close()
}
- val sc = TestSQLContext.sparkContext
+ private lazy val sc = ctx.sparkContext
- val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
- val arr1x2 = Array[Row](Row.apply("fred", 3))
- val schema2 = StructType(
+ private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
+ private lazy val arr1x2 = Array[Row](Row.apply("fred", 3))
+ private lazy val schema2 = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) :: Nil)
- val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
- val schema3 = StructType(
+ private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
+ private lazy val schema3 = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) ::
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
- assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
- assert(2 ==
- TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+ assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+ assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}
test("CREATE with overwrite") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.DROPTEST", properties)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties)
- assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
test("CREATE then INSERT to append") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url, "TEST.APPENDTEST", new Properties)
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties)
- assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
- assert(2 ==
- TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
+ assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
+ assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
}
test("CREATE then INSERT to truncate") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
- assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
+ assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
test("Incompatible INSERT to append") {
- val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
intercept[org.apache.spark.SparkException] {
@@ -142,15 +143,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
}
test("INSERT to JDBC Datasource") {
- TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
- TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index f8d62f9e7e..d889c7be17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -23,21 +23,19 @@ import java.sql.{Date, Timestamp}
import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.util.Utils
-class JsonSuite extends QueryTest {
- import org.apache.spark.sql.json.TestJsonData._
+class JsonSuite extends QueryTest with TestJsonData {
- TestJsonData
+ protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+ import ctx.sql
+ import ctx.implicits._
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
@@ -214,7 +212,7 @@ class JsonSuite extends QueryTest {
}
test("Complex field and type inferring with null in sampling") {
- val jsonDF = read.json(jsonNullStruct)
+ val jsonDF = ctx.read.json(jsonNullStruct)
val expectedSchema = StructType(
StructField("headers", StructType(
StructField("Charset", StringType, true) ::
@@ -233,7 +231,7 @@ class JsonSuite extends QueryTest {
}
test("Primitive field and type inferring") {
- val jsonDF = read.json(primitiveFieldAndType)
+ val jsonDF = ctx.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
@@ -261,7 +259,7 @@ class JsonSuite extends QueryTest {
}
test("Complex field and type inferring") {
- val jsonDF = read.json(complexFieldAndType1)
+ val jsonDF = ctx.read.json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
@@ -360,7 +358,7 @@ class JsonSuite extends QueryTest {
}
test("GetField operation on complex data type") {
- val jsonDF = read.json(complexFieldAndType1)
+ val jsonDF = ctx.read.json(complexFieldAndType1)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -376,7 +374,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in primitive field values") {
- val jsonDF = read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
@@ -450,7 +448,7 @@ class JsonSuite extends QueryTest {
}
ignore("Type conflict in primitive field values (Ignored)") {
- val jsonDF = read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
// Right now, the analyzer does not promote strings in a boolean expression.
@@ -503,7 +501,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in complex field values") {
- val jsonDF = read.json(complexFieldValueTypeConflict)
+ val jsonDF = ctx.read.json(complexFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("array", ArrayType(LongType, true), true) ::
@@ -527,7 +525,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in array elements") {
- val jsonDF = read.json(arrayElementTypeConflict)
+ val jsonDF = ctx.read.json(arrayElementTypeConflict)
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
@@ -555,7 +553,7 @@ class JsonSuite extends QueryTest {
}
test("Handling missing fields") {
- val jsonDF = read.json(missingFields)
+ val jsonDF = ctx.read.json(missingFields)
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
@@ -574,8 +572,9 @@ class JsonSuite extends QueryTest {
val dir = Utils.createTempDir()
dir.delete()
val path = dir.getCanonicalPath
- sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
- val jsonDF = read.option("samplingRatio", "0.49").json(path)
+ ctx.sparkContext.parallelize(1 to 100)
+ .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
+ val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path)
val analyzed = jsonDF.queryExecution.analyzed
assert(
@@ -590,7 +589,7 @@ class JsonSuite extends QueryTest {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
- read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
+ ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.path === Some(path))
assert(relationWithSchema.schema === schema)
@@ -602,7 +601,7 @@ class JsonSuite extends QueryTest {
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = read.json(path)
+ val jsonDF = ctx.read.json(path)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
@@ -671,7 +670,7 @@ class JsonSuite extends QueryTest {
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
- val jsonDF1 = read.schema(schema).json(path)
+ val jsonDF1 = ctx.read.schema(schema).json(path)
assert(schema === jsonDF1.schema)
@@ -688,7 +687,7 @@ class JsonSuite extends QueryTest {
"this is a simple string.")
)
- val jsonDF2 = read.schema(schema).json(primitiveFieldAndType)
+ val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType)
assert(schema === jsonDF2.schema)
@@ -709,7 +708,7 @@ class JsonSuite extends QueryTest {
test("Applying schemas with MapType") {
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
- val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1)
+ val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
@@ -737,7 +736,7 @@ class JsonSuite extends QueryTest {
val schemaWithComplexMap = StructType(
StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
- val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2)
+ val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2)
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
@@ -763,7 +762,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-2096 Correctly parse dot notations") {
- val jsonDF = read.json(complexFieldAndType2)
+ val jsonDF = ctx.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -781,7 +780,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-3390 Complex arrays") {
- val jsonDF = read.json(complexFieldAndType2)
+ val jsonDF = ctx.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -804,7 +803,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-3308 Read top level JSON arrays") {
- val jsonDF = read.json(jsonArray)
+ val jsonDF = ctx.read.json(jsonArray)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -822,10 +821,10 @@ class JsonSuite extends QueryTest {
test("Corrupt records") {
// Test if we can query corrupt records.
- val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
- TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+ val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
- val jsonDF = read.json(corruptRecords)
+ val jsonDF = ctx.read.json(corruptRecords)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
@@ -875,11 +874,11 @@ class JsonSuite extends QueryTest {
Row("]") :: Nil
)
- TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
test("SPARK-4068: nulls in arrays") {
- val jsonDF = read.json(nullsInArrays)
+ val jsonDF = ctx.read.json(nullsInArrays)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
@@ -925,7 +924,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = createDataFrame(rowRDD1, schema1)
+ val df1 = ctx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDF
val result = df2.toJSON.collect()
@@ -948,7 +947,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = createDataFrame(rowRDD2, schema2)
+ val df3 = ctx.createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDF
val result2 = df4.toJSON.collect()
@@ -956,8 +955,8 @@ class JsonSuite extends QueryTest {
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
- val jsonDF = read.json(primitiveFieldAndType)
- val primTable = read.json(jsonDF.toJSON)
+ val jsonDF = ctx.read.json(primitiveFieldAndType)
+ val primTable = ctx.read.json(jsonDF.toJSON)
primTable.registerTempTable("primativeTable")
checkAnswer(
sql("select * from primativeTable"),
@@ -969,8 +968,8 @@ class JsonSuite extends QueryTest {
"this is a simple string.")
)
- val complexJsonDF = read.json(complexFieldAndType1)
- val compTable = read.json(complexJsonDF.toJSON)
+ val complexJsonDF = ctx.read.json(complexFieldAndType1)
+ val compTable = ctx.read.json(complexJsonDF.toJSON)
compTable.registerTempTable("complexTable")
// Access elements of a primitive array.
checkAnswer(
@@ -1074,29 +1073,29 @@ class JsonSuite extends QueryTest {
}
test("SPARK-7565 MapType in JsonRDD") {
- val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
- val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
- TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+ val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
+ val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
try{
for (useStreaming <- List("true", "false")) {
- setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+ ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
val temp = Utils.createTempDir().getPath
- val df = read.schema(schemaWithSimpleMap).json(mapType1)
+ val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
df.write.mode("overwrite").parquet(temp)
// order of MapType is not defined
- assert(read.parquet(temp).count() == 5)
+ assert(ctx.read.parquet(temp).count() == 5)
- val df2 = read.json(corruptRecords)
+ val df2 = ctx.read.json(corruptRecords)
df2.write.mode("overwrite").parquet(temp)
- checkAnswer(read.parquet(temp), df2.collect())
+ checkAnswer(ctx.read.parquet(temp), df2.collect())
}
} finally {
- setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
- setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+ ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index 47a97a49da..b6a6a8dc6a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.json
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
-object TestJsonData {
+trait TestJsonData {
- val primitiveFieldAndType =
- TestSQLContext.sparkContext.parallelize(
+ protected def ctx: SQLContext
+
+ def primitiveFieldAndType: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@@ -32,8 +35,8 @@ object TestJsonData {
"null":null
}""" :: Nil)
- val primitiveFieldValueTypeConflict =
- TestSQLContext.sparkContext.parallelize(
+ def primitiveFieldValueTypeConflict: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -43,15 +46,15 @@ object TestJsonData {
"""{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470,
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
- val jsonNullStruct =
- TestSQLContext.sparkContext.parallelize(
+ def jsonNullStruct: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
- val complexFieldValueTypeConflict =
- TestSQLContext.sparkContext.parallelize(
+ def complexFieldValueTypeConflict: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@@ -61,23 +64,23 @@ object TestJsonData {
"""{"num_struct":{}, "str_array":["str1", "str2", 33],
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
- val arrayElementTypeConflict =
- TestSQLContext.sparkContext.parallelize(
+ def arrayElementTypeConflict: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
- val missingFields =
- TestSQLContext.sparkContext.parallelize(
+ def missingFields: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
"""{"d":{"field":true}}""" ::
"""{"e":"str"}""" :: Nil)
- val complexFieldAndType1 =
- TestSQLContext.sparkContext.parallelize(
+ def complexFieldAndType1: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@@ -92,8 +95,8 @@ object TestJsonData {
"arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]]
}""" :: Nil)
- val complexFieldAndType2 =
- TestSQLContext.sparkContext.parallelize(
+ def complexFieldAndType2: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@@ -146,16 +149,16 @@ object TestJsonData {
]]
}""" :: Nil)
- val mapType1 =
- TestSQLContext.sparkContext.parallelize(
+ def mapType1: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
"""{"map": {"c": 1, "d": 4}}""" ::
"""{"map": {"e": null}}""" :: Nil)
- val mapType2 =
- TestSQLContext.sparkContext.parallelize(
+ def mapType2: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -163,22 +166,22 @@ object TestJsonData {
"""{"map": {"e": null}}""" ::
"""{"map": {"f": {"field1": null}}}""" :: Nil)
- val nullsInArrays =
- TestSQLContext.sparkContext.parallelize(
+ def nullsInArrays: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
- val jsonArray =
- TestSQLContext.sparkContext.parallelize(
+ def jsonArray: RDD[String] =
+ ctx.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
- val corruptRecords =
- TestSQLContext.sparkContext.parallelize(
+ def corruptRecords: RDD[String] =
+ ctx.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@@ -186,6 +189,5 @@ object TestJsonData {
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""]""" :: Nil)
- val empty =
- TestSQLContext.sparkContext.parallelize(Seq[String]())
+ def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index 4aa5bcb7fd..17f5f9a491 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.sources.LogicalRelation
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
@@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
* data type is nullable.
*/
class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
private def checkFilterPredicate(
df: DataFrame,
@@ -312,7 +311,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -341,7 +340,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA
}
class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 7f7c2cc1a6..2b6a27032e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -36,9 +36,6 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode}
@@ -66,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuiteBase extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
-
- import sqlContext.implicits.localSeqToDataFrameHolder
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.implicits._
/**
* Writes `data` to a Parquet file, reads it back and check file contents.
@@ -104,7 +100,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
test("fixed-length decimals") {
def makeDecimalRDD(decimal: DecimalType): DataFrame =
- sparkContext
+ sqlContext.sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
.toDF()
@@ -115,7 +111,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempPath { dir =>
val data = makeDecimalRDD(DecimalType(precision, scale))
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
@@ -123,7 +119,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
- read.parquet(dir.getCanonicalPath).collect()
+ sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
@@ -131,14 +127,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
- read.parquet(dir.getCanonicalPath).collect()
+ sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
}
test("date type") {
def makeDateRDD(): DataFrame =
- sparkContext
+ sqlContext.sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(DateUtils.toJavaDate(i)))
.toDF()
@@ -147,7 +143,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempPath { dir =>
val data = makeDateRDD()
data.write.parquet(dir.getCanonicalPath)
- checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+ checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
@@ -236,7 +232,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
def checkCompressionCodec(codec: CompressionCodecName): Unit = {
withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
withParquetFile(data) { path =>
- assertResult(conf.parquetCompressionCodec.toUpperCase) {
+ assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) {
compressionCodecFor(path)
}
}
@@ -244,7 +240,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
// Checks default compression codec
- checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec))
+ checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec))
checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
checkCompressionCodec(CompressionCodecName.GZIP)
@@ -283,7 +279,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
- checkAnswer(read.parquet(path.toString), (0 until 10).map { i =>
+ checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i =>
Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
})
}
@@ -312,7 +308,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file)
- checkAnswer(read.parquet(file), newData.map(Row.fromTuple))
+ checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple))
}
}
@@ -321,7 +317,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file)
- checkAnswer(read.parquet(file), data.map(Row.fromTuple))
+ checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple))
}
}
@@ -341,7 +337,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file)
- checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple))
+ checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple))
}
}
@@ -369,11 +365,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val path = new Path(location.getCanonicalPath)
ParquetFileWriter.writeMetadataFile(
- sparkContext.hadoopConfiguration,
+ sqlContext.sparkContext.hadoopConfiguration,
path,
new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil)
- assertResult(read.parquet(path.toString).schema) {
+ assertResult(sqlContext.read.parquet(path.toString).schema) {
StructType(
StructField("a", BooleanType, nullable = false) ::
StructField("b", IntegerType, nullable = false) ::
@@ -406,7 +402,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -430,7 +426,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA
}
class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index 3b29979452..8979a0a210 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql.parquet
import java.io.File
@@ -28,7 +29,6 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.sources.PartitioningUtils._
import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec}
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext}
@@ -39,10 +39,10 @@ case class ParquetData(intField: Int, stringField: String)
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
- override val sqlContext: SQLContext = TestSQLContext
- import sqlContext._
+ override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
+ import sqlContext.sql
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
@@ -190,8 +190,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
// Introduce _temporary dir to the base dir the robustness of the schema discovery process.
new File(base.getCanonicalPath, "_temporary").mkdir()
- println("load the partitioned table")
- read.parquet(base.getCanonicalPath).registerTempTable("t")
+ sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -238,7 +237,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- read.parquet(base.getCanonicalPath).registerTempTable("t")
+ sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -286,7 +285,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
+ val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
@@ -326,7 +325,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
- val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
+ val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
@@ -358,7 +357,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
(1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"),
makePartitionDir(base, defaultPartitionName, "pi" -> 2))
- read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t")
+ sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
@@ -371,7 +370,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
test("SPARK-7749 Non-partitioned table should have empty partition spec") {
withTempPath { dir =>
(1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath)
- val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution
+ val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution
queryExecution.analyzed.collectFirst {
case LogicalRelation(relation: ParquetRelation2) =>
assert(relation.partitionSpec === PartitionSpec.emptySpec)
@@ -385,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
withTempPath { dir =>
val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s")
df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath)
- checkAnswer(read.parquet(dir.getCanonicalPath), df.collect())
+ checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect())
}
}
@@ -425,12 +424,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
}
val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
- val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+ val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema)
withTempPath { dir =>
df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
val fields = schema.map(f => Column(f.name).cast(f.dataType))
- checkAnswer(read.load(dir.toString).select(fields: _*), row)
+ checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row)
}
}
@@ -446,7 +445,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store"))
Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
- checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df)
+ checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 304936fb2b..de0107a361 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -22,14 +22,14 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
/**
* A test suite that tests various Parquet queries.
*/
class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+ import sqlContext.implicits._
+ import sqlContext.sql
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -40,22 +40,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
- createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
- checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
+ checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
- catalog.unregisterTable(Seq("tmp"))
+ sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
- createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
- checkAnswer(table("t"), data.map(Row.fromTuple))
+ checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
}
- catalog.unregisterTable(Seq("tmp"))
+ sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
@@ -118,7 +118,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
- val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema)
+ val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
val df2 = sqlContext.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
@@ -127,7 +127,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -139,7 +139,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd
}
class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
- val originalConf = sqlContext.conf.parquetUseDataSourceApi
+ private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
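Note on ParquetQuerySuite: sqlContext becomes a lazy val and the saved originalConf in both subclasses becomes a private lazy val, so merely constructing the suite classes need not touch TestSQLContext (and the SparkContext behind it); initialization is deferred until a test actually runs. A small sketch of that initialization-order effect, using made-up stand-ins (HeavyContext, QuerySuiteBase, DataSourceOnSuite are not Spark classes):

    // HeavyContext plays the role of TestSQLContext, whose construction is expensive.
    object HeavyContext {
      println("HeavyContext initialized")          // visible side effect on first use
      var useDataSourceApi: Boolean = true
    }

    class QuerySuiteBase {
      // lazy: constructing the suite does not initialize the context
      protected lazy val ctx = HeavyContext
    }

    class DataSourceOnSuite extends QuerySuiteBase {
      // also lazy: reading the original setting is deferred as well
      private lazy val originalConf = ctx.useDataSourceApi

      def beforeAll(): Unit = {
        val _ = originalConf                       // force the snapshot before mutating
        ctx.useDataSourceApi = false
      }
      def afterAll(): Unit = { ctx.useDataSourceApi = originalConf }
    }

    object LazyInitDemo extends App {
      val suite = new DataSourceOnSuite
      println("suite constructed")                 // HeavyContext not yet initialized
      suite.beforeAll()                            // "HeavyContext initialized" appears here
      suite.afterAll()
      assert(HeavyContext.useDataSourceApi)        // original value restored
    }

One subtlety with a lazy snapshot: it has to be forced before the first mutation (as in beforeAll above); otherwise it would capture the overridden value rather than the original.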
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index 8b1745124b..171a656f0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -24,11 +24,10 @@ import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
- val sqlContext = TestSQLContext
+ lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
/**
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 516ba373f4..eb15a1609f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -33,8 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
private[sql] trait ParquetTest extends SQLTestUtils {
- import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder}
- import sqlContext.sparkContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
@@ -44,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
- sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath)
+ sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -75,7 +73,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
- data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+ sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
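Note on ParquetTest: with sqlContext declared as a def in SQLTestUtils (next file), this trait can no longer write import sqlContext.implicits._, because an import prefix in Scala must be a stable identifier (a val, lazy val or object), not a def. The implicit Seq-to-DataFrame conversions are therefore replaced by explicit sqlContext.createDataFrame(data) calls. A compact, Spark-free illustration of that rule with invented names (Ctx, TestUtilsLike, ParquetTestLike):

    class Ctx {
      object implicits {
        implicit class RichSeq[A](data: Seq[A]) { def toFrame: Vector[A] = data.toVector }
      }
      def createFrame[A](data: Seq[A]): Vector[A] = data.toVector
    }

    trait TestUtilsLike {
      def ctx: Ctx        // a def is not a stable identifier, so
                          // `import ctx.implicits._` cannot appear in this trait
    }

    trait ParquetTestLike extends TestUtilsLike {
      // explicit factory call replaces the implicit `.toFrame`-style conversion
      def makeFile[A](data: Seq[A]): Vector[A] = ctx.createFrame(data)
    }

    object ExplicitFactoryDemo extends App with ParquetTestLike {
      lazy val ctx = new Ctx            // concrete lazy val implements the abstract def
      import ctx.implicits._            // fine here: ctx is now a stable identifier
      println(makeFile(Seq(1, 2, 3)))   // Vector(1, 2, 3)
      println(Seq(4, 5).toFrame)        // the implicit route is still usable at call sites
    }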
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 17a8b0cca0..ac4a00a6f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -25,11 +25,9 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
trait SQLTestUtils {
- val sqlContext: SQLContext
+ def sqlContext: SQLContext
- import sqlContext.{conf, sparkContext}
-
- protected def configuration = sparkContext.hadoopConfiguration
+ protected def configuration = sqlContext.sparkContext.hadoopConfiguration
/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
@@ -39,12 +37,12 @@ trait SQLTestUtils {
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
- (keys, values).zipped.foreach(conf.setConf)
+ val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption)
+ (keys, values).zipped.foreach(sqlContext.conf.setConf)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => conf.setConf(key, value)
- case (key, None) => conf.unsetConf(key)
+ case (key, Some(value)) => sqlContext.conf.setConf(key, value)
+ case (key, None) => sqlContext.conf.unsetConf(key)
}
}
}
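Note on SQLTestUtils: switching the abstract member from val to def is the loosest contract — concrete suites may implement it with a def, a val, or (as most now do) a lazy val — and it forces the trait itself to qualify conf and sparkContext instead of importing them, since a def is not a valid import prefix. withSQLConf keeps its snapshot / override / try-finally-restore shape; the sketch below reproduces that shape against a toy Conf class (Conf, TestUtils and withConf are illustrative, not Spark APIs).

    import scala.util.Try

    // Minimal mutable config standing in for SQLConf.
    class Conf {
      private val settings = scala.collection.mutable.Map.empty[String, String]
      def getConf(key: String): String = settings(key)     // throws if unset
      def setConf(key: String, value: String): Unit = settings(key) = value
      def unsetConf(key: String): Unit = settings -= key
    }

    trait TestUtils {
      def conf: Conf

      // Same shape as withSQLConf: snapshot current values, set overrides,
      // run the body, then restore (or unset) each key.
      protected def withConf(pairs: (String, String)*)(f: => Unit): Unit = {
        val (keys, values) = pairs.unzip
        val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
        keys.zip(values).foreach { case (k, v) => conf.setConf(k, v) }
        try f finally {
          keys.zip(currentValues).foreach {
            case (key, Some(value)) => conf.setConf(key, value)
            case (key, None)        => conf.unsetConf(key)
          }
        }
      }
    }

    object WithConfDemo extends App with TestUtils {
      lazy val conf = new Conf
      conf.setConf("spark.sql.x", "1")
      withConf("spark.sql.x" -> "2", "spark.sql.y" -> "true") {
        assert(conf.getConf("spark.sql.x") == "2")
      }
      assert(conf.getConf("spark.sql.x") == "1")            // restored
      assert(Try(conf.getConf("spark.sql.y")).isFailure)    // unset again
    }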
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 57c23fe77f..b384fb39f3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -52,9 +52,6 @@ case class Contact(name: String, phone: String)
case class Person(name: String, age: Int, contacts: Seq[Contact])
class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
- override val sqlContext = TestHive
-
- import TestHive.read
def getTempFilePath(prefix: String, suffix: String = ""): File = {
val tempFile = File.createTempFile(prefix, suffix)
@@ -69,7 +66,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withOrcFile(data) { file =>
checkAnswer(
- read.format("orc").load(file),
+ sqlContext.read.format("orc").load(file),
data.toDF().collect())
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index 750f0b04aa..5daf691aa8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -22,13 +22,11 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql._
private[sql] trait OrcTest extends SQLTestUtils {
- protected def hiveContext = sqlContext.asInstanceOf[HiveContext]
+ lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
import sqlContext.sparkContext
import sqlContext.implicits._
@@ -53,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
- withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path)))
+ withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path)))
}
/**
@@ -65,7 +63,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
- hiveContext.registerDataFrameAsTable(df, tableName)
+ sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
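Note on OrcTest: the trait now owns the context wiring for the ORC suites. lazy val sqlContext = TestHive implements the abstract def sqlContext from SQLTestUtils and, being a lazy val, remains a stable identifier, so import sqlContext.implicits._ still compiles; and since the member is already the Hive test context, the old hiveContext = sqlContext.asInstanceOf[HiveContext] helper disappears and OrcQuerySuite no longer needs its own override. A Spark-free sketch of that pattern (Base, Derived, SQLTestUtilsLike, OrcTestLike are hypothetical):

    // Base plays SQLContext, Derived plays the Hive test context.
    class Base {
      object implicits { implicit class Pipe[A](a: A) { def |>[B](f: A => B): B = f(a) } }
      def read(path: String): String = s"generic read of $path"
    }

    class Derived extends Base {
      def registerTable(name: String): String = s"registered $name"   // Derived-only extra
    }

    trait SQLTestUtilsLike {
      def ctx: Base                  // loose contract from the shared test utilities
    }

    trait OrcTestLike extends SQLTestUtilsLike {
      // lazy val: stable identifier (imports still work) and a narrower type,
      // so Derived-only members need no asInstanceOf cast
      lazy val ctx: Derived = new Derived
      import ctx.implicits._

      def demo(): String = ctx.registerTable("t") |> (_.toUpperCase)
    }

    object NarrowedContextDemo extends App with OrcTestLike {
      println(demo())                // REGISTERED T
    }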