author     Wenchen Fan <cloud0fan@outlook.com>    2015-09-04 15:17:37 -0700
committer  Andrew Or <andrew@databricks.com>      2015-09-04 15:17:37 -0700
commit     c3c0e431a6280fbcf726ac9bc4db0e1b5a862be8 (patch)
tree       455aba97f1c8abba6a6076be2d6bbc8a535246c1 /sql
parent     804a0126e0cc982cc9f22cc76ba7b874ebbef5dd (diff)
[SPARK-10176] [SQL] Show partially analyzed plans when checkAnswer fails to analyze
This PR takes over https://github.com/apache/spark/pull/8389. It improves `checkAnswer` to print the partially analyzed plan in addition to the user-friendly error message, in order to aid debugging of failing tests. In doing so, I ran into a conflict between the various ways that we bring a SQLContext into the tests: depending on the trait, we refer to the current context as `sqlContext`, `_sqlContext`, `ctx` or `hiveContext`, with `public`, `protected` or `private` access depending on the defining class. I propose we refactor as follows (a sketch of the resulting convention follows below):

1. All tests should refer only to a `protected sqlContext` when testing general features, and to a `protected hiveContext` only when the method under test exists solely on `HiveContext`.
2. All tests should import only `testImplicits._` (i.e., don't import `TestHive.implicits._`).

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8584 from cloud-fan/cleanupTests.
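A minimal sketch of the proposed convention, not part of the patch itself. The suite name and test body are hypothetical; it assumes the `SharedSQLContext` trait refactored in this change, which exposes a protected `sqlContext` (plus `sparkContext`) and `testImplicits`:

    import org.apache.spark.sql.{QueryTest, Row}
    import org.apache.spark.sql.test.SharedSQLContext

    // Hypothetical suite; illustrates the naming convention only.
    class ExampleFeatureSuite extends QueryTest with SharedSQLContext {
      // Import only the shared implicits, never TestHive.implicits._.
      import testImplicits._

      test("toDF on a small range") {
        // Always refer to the context as `sqlContext`, not `ctx` or `_sqlContext`.
        val df = sqlContext.range(0, 3).toDF("id")
        checkAnswer(df, Seq(Row(0L), Row(1L), Row(2L)))
      }
    }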
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 156
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala | 20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala | 44
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | 47
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 42
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 41
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala | 20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 99
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 214
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala | 34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 52
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala | 42
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 24
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 52
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala | 52
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 41
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala | 17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala | 20
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala | 13
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala | 6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 20
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 35
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 13
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala | 20
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala | 24
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala | 18
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 42
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala | 16
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 19
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala | 11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 54
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 39
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala | 17
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala | 23
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala | 10
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala | 9
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala | 16
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala | 9
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala | 15
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala | 28
90 files changed, 908 insertions, 999 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 765c1e2dda..f76a903dcc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util._
* Provides helper methods for comparing plans.
*/
class PlanTest extends SparkFunSuite {
-
/**
* Since attribute references are given globally unique ids during analysis,
* we must normalize them to check if two different queries are identical.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index af7590c3d3..3a3541a842 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -34,7 +34,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
import testImplicits._
def rddIdOf(tableName: String): Int = {
- val executedPlan = ctx.table(tableName).queryExecution.executedPlan
+ val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan
executedPlan.collect {
case InMemoryColumnarTableScan(_, _, relation) =>
relation.cachedColumnBuffers.id
@@ -44,7 +44,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
}
def isMaterialized(rddId: Int): Boolean = {
- ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty
+ sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty
}
test("withColumn doesn't invalidate cached dataframe") {
@@ -69,41 +69,41 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("cache temp table") {
testData.select('key).registerTempTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0)
- ctx.cacheTable("tempTable")
+ sqlContext.cacheTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"))
- ctx.uncacheTable("tempTable")
+ sqlContext.uncacheTable("tempTable")
}
test("unpersist an uncached table will not raise exception") {
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None == sqlContext.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None == sqlContext.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None == sqlContext.cacheManager.lookupCachedData(testData))
testData.persist()
- assert(None != ctx.cacheManager.lookupCachedData(testData))
+ assert(None != sqlContext.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None == sqlContext.cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
- assert(None == ctx.cacheManager.lookupCachedData(testData))
+ assert(None == sqlContext.cacheManager.lookupCachedData(testData))
}
test("cache table as select") {
sql("CACHE TABLE tempTable AS SELECT key FROM testData")
assertCached(sql("SELECT COUNT(*) FROM tempTable"))
- ctx.uncacheTable("tempTable")
+ sqlContext.uncacheTable("tempTable")
}
test("uncaching temp table") {
testData.select('key).registerTempTable("tempTable1")
testData.select('key).registerTempTable("tempTable2")
- ctx.cacheTable("tempTable1")
+ sqlContext.cacheTable("tempTable1")
assertCached(sql("SELECT COUNT(*) FROM tempTable1"))
assertCached(sql("SELECT COUNT(*) FROM tempTable2"))
// Is this valid?
- ctx.uncacheTable("tempTable2")
+ sqlContext.uncacheTable("tempTable2")
// Should this be cached?
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0)
@@ -111,103 +111,103 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("too big for memory") {
val data = "*" * 1000
- ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
+ sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF()
.registerTempTable("bigData")
- ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
- assert(ctx.table("bigData").count() === 200000L)
- ctx.table("bigData").unpersist(blocking = true)
+ sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
+ assert(sqlContext.table("bigData").count() === 200000L)
+ sqlContext.table("bigData").unpersist(blocking = true)
}
test("calling .cache() should use in-memory columnar caching") {
- ctx.table("testData").cache()
- assertCached(ctx.table("testData"))
- ctx.table("testData").unpersist(blocking = true)
+ sqlContext.table("testData").cache()
+ assertCached(sqlContext.table("testData"))
+ sqlContext.table("testData").unpersist(blocking = true)
}
test("calling .unpersist() should drop in-memory columnar cache") {
- ctx.table("testData").cache()
- ctx.table("testData").count()
- ctx.table("testData").unpersist(blocking = true)
- assertCached(ctx.table("testData"), 0)
+ sqlContext.table("testData").cache()
+ sqlContext.table("testData").count()
+ sqlContext.table("testData").unpersist(blocking = true)
+ assertCached(sqlContext.table("testData"), 0)
}
test("isCached") {
- ctx.cacheTable("testData")
+ sqlContext.cacheTable("testData")
- assertCached(ctx.table("testData"))
- assert(ctx.table("testData").queryExecution.withCachedData match {
+ assertCached(sqlContext.table("testData"))
+ assert(sqlContext.table("testData").queryExecution.withCachedData match {
case _: InMemoryRelation => true
case _ => false
})
- ctx.uncacheTable("testData")
- assert(!ctx.isCached("testData"))
- assert(ctx.table("testData").queryExecution.withCachedData match {
+ sqlContext.uncacheTable("testData")
+ assert(!sqlContext.isCached("testData"))
+ assert(sqlContext.table("testData").queryExecution.withCachedData match {
case _: InMemoryRelation => false
case _ => true
})
}
test("SPARK-1669: cacheTable should be idempotent") {
- assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
+ assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
- ctx.cacheTable("testData")
- assertCached(ctx.table("testData"))
+ sqlContext.cacheTable("testData")
+ assertCached(sqlContext.table("testData"))
assertResult(1, "InMemoryRelation not found, testData should have been cached") {
- ctx.table("testData").queryExecution.withCachedData.collect {
+ sqlContext.table("testData").queryExecution.withCachedData.collect {
case r: InMemoryRelation => r
}.size
}
- ctx.cacheTable("testData")
+ sqlContext.cacheTable("testData")
assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
- ctx.table("testData").queryExecution.withCachedData.collect {
+ sqlContext.table("testData").queryExecution.withCachedData.collect {
case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r
}.size
}
- ctx.uncacheTable("testData")
+ sqlContext.uncacheTable("testData")
}
test("read from cached table and uncache") {
- ctx.cacheTable("testData")
- checkAnswer(ctx.table("testData"), testData.collect().toSeq)
- assertCached(ctx.table("testData"))
+ sqlContext.cacheTable("testData")
+ checkAnswer(sqlContext.table("testData"), testData.collect().toSeq)
+ assertCached(sqlContext.table("testData"))
- ctx.uncacheTable("testData")
- checkAnswer(ctx.table("testData"), testData.collect().toSeq)
- assertCached(ctx.table("testData"), 0)
+ sqlContext.uncacheTable("testData")
+ checkAnswer(sqlContext.table("testData"), testData.collect().toSeq)
+ assertCached(sqlContext.table("testData"), 0)
}
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
- ctx.uncacheTable("testData")
+ sqlContext.uncacheTable("testData")
}
}
test("SELECT star from cached table") {
sql("SELECT * FROM testData").registerTempTable("selectStar")
- ctx.cacheTable("selectStar")
+ sqlContext.cacheTable("selectStar")
checkAnswer(
sql("SELECT * FROM selectStar WHERE key = 1"),
Seq(Row(1, "1")))
- ctx.uncacheTable("selectStar")
+ sqlContext.uncacheTable("selectStar")
}
test("Self-join cached") {
val unCachedAnswer =
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
- ctx.cacheTable("testData")
+ sqlContext.cacheTable("testData")
checkAnswer(
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
unCachedAnswer.toSeq)
- ctx.uncacheTable("testData")
+ sqlContext.uncacheTable("testData")
}
test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
sql("CACHE TABLE testData")
- assertCached(ctx.table("testData"))
+ assertCached(sqlContext.table("testData"))
val rddId = rddIdOf("testData")
assert(
@@ -215,7 +215,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
"Eagerly cached in-memory table should have already been materialized")
sql("UNCACHE TABLE testData")
- assert(!ctx.isCached("testData"), "Table 'testData' should not be cached")
+ assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
@@ -224,14 +224,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("CACHE TABLE tableName AS SELECT * FROM anotherTable") {
sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- assertCached(ctx.table("testCacheTable"))
+ assertCached(sqlContext.table("testCacheTable"))
val rddId = rddIdOf("testCacheTable")
assert(
isMaterialized(rddId),
"Eagerly cached in-memory table should have already been materialized")
- ctx.uncacheTable("testCacheTable")
+ sqlContext.uncacheTable("testCacheTable")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -239,14 +239,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("CACHE TABLE tableName AS SELECT ...") {
sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10")
- assertCached(ctx.table("testCacheTable"))
+ assertCached(sqlContext.table("testCacheTable"))
val rddId = rddIdOf("testCacheTable")
assert(
isMaterialized(rddId),
"Eagerly cached in-memory table should have already been materialized")
- ctx.uncacheTable("testCacheTable")
+ sqlContext.uncacheTable("testCacheTable")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -254,7 +254,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("CACHE LAZY TABLE tableName") {
sql("CACHE LAZY TABLE testData")
- assertCached(ctx.table("testData"))
+ assertCached(sqlContext.table("testData"))
val rddId = rddIdOf("testData")
assert(
@@ -266,7 +266,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
isMaterialized(rddId),
"Lazily cached in-memory table should have been materialized")
- ctx.uncacheTable("testData")
+ sqlContext.uncacheTable("testData")
eventually(timeout(10 seconds)) {
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
@@ -274,7 +274,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("InMemoryRelation statistics") {
sql("CACHE TABLE testData")
- ctx.table("testData").queryExecution.withCachedData.collect {
+ sqlContext.table("testData").queryExecution.withCachedData.collect {
case cached: InMemoryRelation =>
val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum
assert(cached.statistics.sizeInBytes === actualSizeInBytes)
@@ -283,46 +283,48 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
test("Drops temporary table") {
testData.select('key).registerTempTable("t1")
- ctx.table("t1")
- ctx.dropTempTable("t1")
- assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found"))
+ sqlContext.table("t1")
+ sqlContext.dropTempTable("t1")
+ assert(
+ intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found"))
}
test("Drops cached temporary table") {
testData.select('key).registerTempTable("t1")
testData.select('key).registerTempTable("t2")
- ctx.cacheTable("t1")
+ sqlContext.cacheTable("t1")
- assert(ctx.isCached("t1"))
- assert(ctx.isCached("t2"))
+ assert(sqlContext.isCached("t1"))
+ assert(sqlContext.isCached("t2"))
- ctx.dropTempTable("t1")
- assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found"))
- assert(!ctx.isCached("t2"))
+ sqlContext.dropTempTable("t1")
+ assert(
+ intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found"))
+ assert(!sqlContext.isCached("t2"))
}
test("Clear all cache") {
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- ctx.cacheTable("t1")
- ctx.cacheTable("t2")
- ctx.clearCache()
- assert(ctx.cacheManager.isEmpty)
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
+ sqlContext.clearCache()
+ assert(sqlContext.cacheManager.isEmpty)
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- ctx.cacheTable("t1")
- ctx.cacheTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
sql("Clear CACHE")
- assert(ctx.cacheManager.isEmpty)
+ assert(sqlContext.cacheManager.isEmpty)
}
test("Clear accumulators when uncacheTable to prevent memory leaking") {
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- ctx.cacheTable("t1")
- ctx.cacheTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
@@ -331,8 +333,8 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
Accumulators.synchronized {
val accsSize = Accumulators.originals.size
- ctx.uncacheTable("t1")
- ctx.uncacheTable("t2")
+ sqlContext.uncacheTable("t1")
+ sqlContext.uncacheTable("t2")
assert((accsSize - 2) == Accumulators.originals.size)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 37738ec5b3..4e988f074b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val booleanData = {
- ctx.createDataFrame(ctx.sparkContext.parallelize(
+ sqlContext.createDataFrame(sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
@@ -286,7 +286,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("isNaN") {
- val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
+ val testData = sqlContext.createDataFrame(sparkContext.parallelize(
Row(Double.NaN, Float.NaN) ::
Row(math.log(-1), math.log(-3).toFloat) ::
Row(null, null) ::
@@ -307,7 +307,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("nanvl") {
- val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
+ val testData = sqlContext.createDataFrame(sparkContext.parallelize(
Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
StructField("c", DoubleType), StructField("d", DoubleType),
@@ -350,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("!==") {
- val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize(
+ val nullData = sqlContext.createDataFrame(sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -411,7 +411,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
test("between") {
- val testData = ctx.sparkContext.parallelize(
+ val testData = sparkContext.parallelize(
(0, 1, 2) ::
(1, 2, 3) ::
(2, 1, 0) ::
@@ -556,7 +556,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
test("monotonicallyIncreasingId") {
// Make sure we have 2 partitions, each with 2 records.
- val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ =>
+ val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ =>
Iterator(Tuple1(1), Tuple1(2))
}.toDF("a")
checkAnswer(
@@ -567,7 +567,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
test("sparkPartitionId") {
// Make sure we have 2 partitions, each with 2 records.
- val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ =>
+ val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ =>
Iterator(Tuple1(1), Tuple1(2))
}.toDF("a")
checkAnswer(
@@ -578,7 +578,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
test("InputFileName") {
withTempPath { dir =>
- val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id")
+ val data = sparkContext.parallelize(0 to 10).toDF("id")
data.write.parquet(dir.getCanonicalPath)
val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName())
.head.getString(0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 72cf7aab0b..c0950b09b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -66,12 +66,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
- ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false)
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false)
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(3), Row(3), Row(3))
)
- ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true)
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true)
}
test("agg without groups") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 3c359dd840..09f7b50767 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -28,19 +28,19 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
test("UDF on struct") {
val f = udf((a: String) => a)
- val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
df.select(struct($"a").as("s")).select(f($"s.a")).collect()
}
test("UDF on named_struct") {
val f = udf((a: String) => a)
- val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
}
test("UDF on array") {
val f = udf((a: String) => a)
- val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index e5d7d63441..094efbaead 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -24,7 +24,7 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
test("RDD of tuples") {
checkAnswer(
- ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
+ sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
(1 to 10).map(i => Row(i, i.toString)))
}
@@ -36,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
test("RDD[Int]") {
checkAnswer(
- ctx.sparkContext.parallelize(1 to 10).toDF("intCol"),
+ sparkContext.parallelize(1 to 10).toDF("intCol"),
(1 to 10).map(i => Row(i)))
}
test("RDD[Long]") {
checkAnswer(
- ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"),
+ sparkContext.parallelize(1L to 10L).toDF("longCol"),
(1L to 10L).map(i => Row(i)))
}
test("RDD[String]") {
checkAnswer(
- ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
+ sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
(1 to 10).map(i => Row(i.toString)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 28bdd6f83b..6524abcf5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -29,7 +29,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
test("sample with replacement") {
val n = 100
- val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = true, 0.05, seed = 13),
Seq(5, 10, 52, 73).map(Row(_))
@@ -38,7 +38,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
test("sample without replacement") {
val n = 100
- val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = false, 0.05, seed = 13),
Seq(16, 23, 88, 100).map(Row(_))
@@ -47,7 +47,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
test("randomSplit") {
val n = 600
- val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
@@ -164,7 +164,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
test("Frequent Items 2") {
- val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4)
+ val rows = sparkContext.parallelize(Seq.empty[Int], 4)
// this is a regression test, where when merging partitions, we omitted values with higher
// counts than those that existed in the map when the map was full. This test should also fail
// if anything like SPARK-9614 is observed once again
@@ -182,7 +182,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
test("sampleBy") {
- val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
+ val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a4871e247c..b5b9f11785 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -345,7 +345,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("replace column using withColumn") {
- val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
+ val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
checkAnswer(
df3.select("x"),
@@ -506,7 +506,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("showString: truncate = [true, false]") {
val longString = Array.fill(21)("1").mkString
- val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF()
+ val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = """+---------------------+
||_1 |
|+---------------------+
@@ -596,7 +596,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
- val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+ val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
val df = sqlContext.createDataFrame(rowRDD, schema)
df.rdd.collect()
@@ -619,14 +619,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
- val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ val df = sqlContext.read.json(sparkContext.makeRDD(
"""{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df.select(df("`a.b`.c.`d..e`.`f`")),
Row(1)
)
- val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ val df2 = sqlContext.read.json(sparkContext.makeRDD(
"""{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil))
checkAnswer(
df2.select(df2("`a b`.c.d e.f")),
@@ -646,7 +646,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-7324 dropDuplicates") {
- val testData = sqlContext.sparkContext.parallelize(
+ val testData = sparkContext.parallelize(
(2, 1, 2) :: (1, 1, 1) ::
(1, 2, 1) :: (2, 1, 2) ::
(2, 2, 2) :: (2, 2, 1) ::
@@ -869,7 +869,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-9323: DataFrame.orderBy should support nested column name") {
- val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ val df = sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
checkAnswer(df.orderBy("a.b"), Row(Row(1)))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index 77907e9136..7ae12a7895 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -32,7 +32,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
test("test simple types") {
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
- val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
+ val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
}
}
@@ -40,7 +40,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
test("test struct type") {
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
val struct = Row(1, 2L, 3.0F, 3.0)
- val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct)))
+ val data = sparkContext.parallelize(Seq(Row(1, struct)))
val schema = new StructType()
.add("a", IntegerType)
@@ -60,7 +60,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
val innerStruct = Row(1, "abcd")
val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
- val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct)))
+ val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))
val schema = new StructType()
.add("a", IntegerType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 8d2f45d703..78a98798ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -52,7 +52,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
try {
sqlContext.experimental.extraStrategies = TestStrategy :: Nil
- val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
+ val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
checkAnswer(
df.select("a"),
Row("so fast"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f5c5046a8e..b05435bad5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -31,7 +31,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
- val planned = ctx.planner.EquiJoinSelection(join)
+ val planned = sqlContext.planner.EquiJoinSelection(join)
assert(planned.size === 1)
}
@@ -59,7 +59,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("join operator selection") {
- ctx.cacheManager.clearCache()
+ sqlContext.cacheManager.clearCache()
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
@@ -118,7 +118,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("broadcasted hash join operator selection") {
- ctx.cacheManager.clearCache()
+ sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
for (sortMergeJoinEnabled <- Seq(true, false)) {
withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
@@ -138,7 +138,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("broadcasted hash outer join operator selection") {
- ctx.cacheManager.clearCache()
+ sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
Seq(
@@ -167,7 +167,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val x = testData2.as("x")
val y = testData2.as("y")
val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
- val planned = ctx.planner.EquiJoinSelection(join)
+ val planned = sqlContext.planner.EquiJoinSelection(join)
assert(planned.size === 1)
}
@@ -442,7 +442,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("broadcasted left semi join operator selection") {
- ctx.cacheManager.clearCache()
+ sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index babf8835d2..eab0fbb196 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -32,33 +32,33 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
}
after {
- ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
}
test("get all tables") {
checkAnswer(
- ctx.tables().filter("tableName = 'ListTablesSuiteTable'"),
+ sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
checkAnswer(
sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
- assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
test("getting all Tables with a database name has no impact on returned table names") {
checkAnswer(
- ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+ sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
checkAnswer(
sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
- assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
test("query the returned DataFrame of tables") {
@@ -66,7 +66,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
- Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
+ Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
@@ -77,9 +77,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
Row(true, "ListTablesSuiteTable")
)
checkAnswer(
- ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
- ctx.dropTempTable("tables")
+ sqlContext.dropTempTable("tables")
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3649c2a97b..cada03e9ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
-class QueryTest extends PlanTest {
+abstract class QueryTest extends PlanTest {
+
+ protected def sqlContext: SQLContext
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
@@ -56,18 +58,33 @@ class QueryTest extends PlanTest {
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
- protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
- QueryTest.checkAnswer(df, expectedAnswer) match {
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ val analyzedDF = try df catch {
+ case ae: AnalysisException =>
+ val currentValue = sqlContext.conf.dataFrameEagerAnalysis
+ sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
+ val partiallyAnalzyedPlan = df.queryExecution.analyzed
+ sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue)
+ fail(
+ s"""
+ |Failed to analyze query: $ae
+ |$partiallyAnalzyedPlan
+ |
+ |${stackTraceToString(ae)}
+ |""".stripMargin)
+ }
+
+ QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}
- protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(df, Seq(expectedAnswer))
}
- protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+ protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
checkAnswer(df, expectedAnswer.collect())
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 77ccd6f775..3ba14d7602 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -57,7 +57,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext {
test("serialize w/ kryo") {
val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first()
- val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf)
+ val serializer = new SparkSqlSerializer(sparkContext.getConf)
val instance = serializer.newInstance()
val ser = instance.serialize(row)
val de = instance.deserialize(ser).asInstanceOf[Row]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 7699adadd9..c35b31c96d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -27,58 +27,58 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
test("propagate from spark conf") {
// We create a new context here to avoid order dependence with other tests that might call
// clear().
- val newContext = new SQLContext(ctx.sparkContext)
+ val newContext = new SQLContext(sparkContext)
assert(newContext.getConf("spark.sql.testkey", "false") === "true")
}
test("programmatic ways of basic setting and getting") {
- ctx.conf.clear()
- assert(ctx.getAllConfs.size === 0)
+ sqlContext.conf.clear()
+ assert(sqlContext.getAllConfs.size === 0)
- ctx.setConf(testKey, testVal)
- assert(ctx.getConf(testKey) === testVal)
- assert(ctx.getConf(testKey, testVal + "_") === testVal)
- assert(ctx.getAllConfs.contains(testKey))
+ sqlContext.setConf(testKey, testVal)
+ assert(sqlContext.getConf(testKey) === testVal)
+ assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
+ assert(sqlContext.getAllConfs.contains(testKey))
// Tests SQLConf as accessed from a SQLContext is mutable after
// the latter is initialized, unlike SparkConf inside a SparkContext.
- assert(ctx.getConf(testKey) == testVal)
- assert(ctx.getConf(testKey, testVal + "_") === testVal)
- assert(ctx.getAllConfs.contains(testKey))
+ assert(sqlContext.getConf(testKey) == testVal)
+ assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
+ assert(sqlContext.getAllConfs.contains(testKey))
- ctx.conf.clear()
+ sqlContext.conf.clear()
}
test("parse SQL set commands") {
- ctx.conf.clear()
+ sqlContext.conf.clear()
sql(s"set $testKey=$testVal")
- assert(ctx.getConf(testKey, testVal + "_") === testVal)
- assert(ctx.getConf(testKey, testVal + "_") === testVal)
+ assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
+ assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
sql("set some.property=20")
- assert(ctx.getConf("some.property", "0") === "20")
+ assert(sqlContext.getConf("some.property", "0") === "20")
sql("set some.property = 40")
- assert(ctx.getConf("some.property", "0") === "40")
+ assert(sqlContext.getConf("some.property", "0") === "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
sql(s"set $key=$vs")
- assert(ctx.getConf(key, "0") === vs)
+ assert(sqlContext.getConf(key, "0") === vs)
sql(s"set $key=")
- assert(ctx.getConf(key, "0") === "")
+ assert(sqlContext.getConf(key, "0") === "")
- ctx.conf.clear()
+ sqlContext.conf.clear()
}
test("deprecated property") {
- ctx.conf.clear()
+ sqlContext.conf.clear()
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
- assert(ctx.conf.numShufflePartitions === 10)
+ assert(sqlContext.conf.numShufflePartitions === 10)
}
test("invalid conf value") {
- ctx.conf.clear()
+ sqlContext.conf.clear()
val e = intercept[IllegalArgumentException] {
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index 007be12950..dd88ae3700 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -24,7 +24,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
override def afterAll(): Unit = {
try {
- SQLContext.setLastInstantiatedContext(ctx)
+ SQLContext.setLastInstantiatedContext(sqlContext)
} finally {
super.afterAll()
}
@@ -32,18 +32,18 @@ class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
test("getOrCreate instantiates SQLContext") {
SQLContext.clearLastInstantiatedContext()
- val sqlContext = SQLContext.getOrCreate(ctx.sparkContext)
+ val sqlContext = SQLContext.getOrCreate(sparkContext)
assert(sqlContext != null, "SQLContext.getOrCreate returned null")
- assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext),
+ assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext),
"SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
}
test("getOrCreate gets last explicitly instantiated SQLContext") {
SQLContext.clearLastInstantiatedContext()
- val sqlContext = new SQLContext(ctx.sparkContext)
- assert(SQLContext.getOrCreate(ctx.sparkContext) != null,
+ val sqlContext = new SQLContext(sparkContext)
+ assert(SQLContext.getOrCreate(sparkContext) != null,
"SQLContext.getOrCreate after explicitly created SQLContext returned null")
- assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext),
+ assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0ef25fe0fa..05f2000459 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -147,14 +147,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SQL Dialect Switching to a new SQL parser") {
- val newContext = new SQLContext(sqlContext.sparkContext)
+ val newContext = new SQLContext(sparkContext)
newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
assert(newContext.getSQLDialect().getClass === classOf[MyDialect])
assert(newContext.sql("SELECT 1").collect() === Array(Row(1)))
}
test("SQL Dialect Switch to an invalid parser with alias") {
- val newContext = new SQLContext(sqlContext.sparkContext)
+ val newContext = new SQLContext(sparkContext)
newContext.sql("SET spark.sql.dialect=MyTestClass")
intercept[DialectException] {
newContext.sql("SELECT 1")
@@ -196,7 +196,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("grouping on nested fields") {
- sqlContext.read.json(sqlContext.sparkContext.parallelize(
+ sqlContext.read.json(sparkContext.parallelize(
"""{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
.registerTempTable("rows")
@@ -215,7 +215,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-6201 IN type conversion") {
sqlContext.read.json(
- sqlContext.sparkContext.parallelize(
+ sparkContext.parallelize(
Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}")))
.registerTempTable("d")
@@ -1342,7 +1342,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-3483 Special chars in column names") {
- val data = sqlContext.sparkContext.parallelize(
+ val data = sparkContext.parallelize(
Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
sqlContext.read.json(data).registerTempTable("records")
sql("SELECT `key?number1`, `key.number2` FROM records")
@@ -1385,13 +1385,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-4322 Grouping field with struct field as sub expression") {
- sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
+ sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil))
.registerTempTable("data")
checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1))
sqlContext.dropTempTable("data")
sqlContext.read.json(
- sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
+ sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2))
sqlContext.dropTempTable("data")
}
@@ -1412,10 +1412,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("Supporting relational operator '<=>' in Spark SQL") {
val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil
- val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
+ val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
rdd1.toDF().registerTempTable("nulldata1")
val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil
- val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
+ val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
rdd2.toDF().registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
@@ -1424,7 +1424,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
- val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i)))
+ val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.toDF().registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
@@ -1432,14 +1432,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-4699 case sensitivity SQL query") {
sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
- val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i)))
+ val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.toDF().registerTempTable("testTable1")
checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
sqlContext.setConf(SQLConf.CASE_SENSITIVE, true)
}
test("SPARK-6145: ORDER BY test for nested fields") {
- sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil))
.registerTempTable("nestedOrder")
@@ -1452,14 +1452,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-6145: special cases") {
- sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
- sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
.registerTempTable("t")
@@ -1543,7 +1543,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-7067: order by queries for complex ExtractValue chain") {
withTempTable("t") {
- sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
}
@@ -1610,8 +1610,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("aggregation with codegen updates peak execution memory") {
withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) {
- val sc = sqlContext.sparkContext
- AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") {
testCodeGen(
"SELECT key, count(value) FROM testData GROUP BY key",
(1 to 100).map(i => Row(i, 1)))
@@ -1670,8 +1669,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("external sorting updates peak execution memory") {
withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
- val sc = sqlContext.sparkContext
- AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") {
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sortTest()
}
}
@@ -1679,7 +1677,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-9511: error with table starting with number") {
withTempTable("1one") {
- sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
+ sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
.toDF("num", "str")
.registerTempTable("1one")
checkAnswer(sql("select count(num) from 1one"), Row(10))
@@ -1690,7 +1688,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
withTempPath { dir =>
val path = dir.getCanonicalPath
val df =
- sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
+ sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
df
.write
.format("parquet")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index 45d0ee4a8e..ddab918629 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext
class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
- val _sqlContext = new SQLContext(sqlContext.sparkContext)
+ val _sqlContext = new SQLContext(sparkContext)
new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index b91438baea..e12e6bea30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -268,9 +268,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row(3, 4))
intercept[AnalysisException] {
- checkAnswer(
- df.selectExpr("length(c)"), // int type of the argument is unacceptable
- Row("5.0000"))
+ df.selectExpr("length(c)") // int type of the argument is unacceptable
}
}
@@ -284,63 +282,46 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}
test("number format function") {
- val tuple =
- ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
- 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
- val df =
- Seq(tuple)
- .toDF(
- "a", // string "aa"
- "b", // byte 1
- "c", // short 2
- "d", // float 3.13223f
- "e", // integer 4
- "f", // long 5L
- "g", // double 6.48173d
- "h") // decimal 7.128381
-
- checkAnswer(
- df.select(format_number($"f", 4)),
+ val df = sqlContext.range(1)
+
+ checkAnswer(
+ df.select(format_number(lit(5L), 4)),
Row("5.0000"))
checkAnswer(
- df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
+ df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer
Row("1.0000"))
checkAnswer(
- df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
+ df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer
Row("2.0000"))
checkAnswer(
- df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
+ df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double
Row("3.1322"))
checkAnswer(
- df.selectExpr("format_number(e, e)"), // not convert anything
+ df.select(format_number(lit(4), 4)), // not convert anything
Row("4.0000"))
checkAnswer(
- df.selectExpr("format_number(f, e)"), // not convert anything
+ df.select(format_number(lit(5L), 4)), // not convert anything
Row("5.0000"))
checkAnswer(
- df.selectExpr("format_number(g, e)"), // not convert anything
+ df.select(format_number(lit(6.48173), 4)), // not convert anything
Row("6.4817"))
checkAnswer(
- df.selectExpr("format_number(h, e)"), // not convert anything
+ df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything
Row("7.1284"))
intercept[AnalysisException] {
- checkAnswer(
- df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
- Row("5.0000"))
+ df.select(format_number(lit("aa"), 4)) // string type of the 1st argument is unacceptable
}
intercept[AnalysisException] {
- checkAnswer(
- df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
- Row("5.0000"))
+      df.selectExpr("format_number(4, 6.48173)") // a non-integral 2nd argument is unacceptable
}
// for testing the mutable state of the expression in code gen.
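The rewritten number-format block drives format_number through typed literal columns on a one-row DataFrame instead of a hand-built eight-column tuple table. The same calls in isolation, assuming a sqlContext in scope as in the suite; the expected strings are taken from the checkAnswer rows above:

  import org.apache.spark.sql.functions.{format_number, lit}

  val df = sqlContext.range(1)                         // single-row anchor DataFrame
  df.select(format_number(lit(5L), 4)).collect()       // Array(Row("5.0000"))
  df.select(format_number(lit(6.48173), 4)).collect()  // Array(Row("6.4817"))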
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index eb275af101..e0435a0dba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("built-in fixed arity expressions") {
- val df = ctx.emptyDataFrame
+ val df = sqlContext.emptyDataFrame
df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)")
}
@@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext {
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
df.registerTempTable("tmp_table")
checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
- ctx.dropTempTable("tmp_table")
+ sqlContext.dropTempTable("tmp_table")
}
test("SPARK-8005 input_file_name") {
withTempPath { dir =>
- val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
+ val data = sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet(dir.getCanonicalPath)
- ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
+ sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
val answer = sql("select input_file_name() from test_table").head().getString(0)
assert(answer.contains(dir.getCanonicalPath))
assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
- ctx.dropTempTable("test_table")
+ sqlContext.dropTempTable("test_table")
}
}
test("error reporting for incorrect number of arguments") {
- val df = ctx.emptyDataFrame
+ val df = sqlContext.emptyDataFrame
val e = intercept[AnalysisException] {
df.selectExpr("substr('abcd', 2, 3, 4)")
}
@@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("error reporting for undefined functions") {
- val df = ctx.emptyDataFrame
+ val df = sqlContext.emptyDataFrame
val e = intercept[AnalysisException] {
df.selectExpr("a_function_that_does_not_exist()")
}
@@ -87,24 +87,24 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("Simple UDF") {
- ctx.udf.register("strLenScala", (_: String).length)
+ sqlContext.udf.register("strLenScala", (_: String).length)
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
- ctx.udf.register("random0", () => { Math.random()})
+ sqlContext.udf.register("random0", () => { Math.random()})
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
- ctx.udf.register("strLenScala", (_: String).length + (_: Int))
+ sqlContext.udf.register("strLenScala", (_: String).length + (_: Int))
assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("UDF in a WHERE") {
- ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 })
+ sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 })
- val df = ctx.sparkContext.parallelize(
+ val df = sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("integerData")
@@ -114,7 +114,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("UDF in a HAVING") {
- ctx.udf.register("havingFilter", (n: Long) => { n > 5 })
+ sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 })
val df = Seq(("red", 1), ("red", 2), ("blue", 10),
("green", 100), ("green", 200)).toDF("g", "v")
@@ -133,7 +133,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("UDF in a GROUP BY") {
- ctx.udf.register("groupFunction", (n: Int) => { n > 10 })
+ sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
val df = Seq(("red", 1), ("red", 2), ("blue", 10),
("green", 100), ("green", 200)).toDF("g", "v")
@@ -150,10 +150,10 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("UDFs everywhere") {
- ctx.udf.register("groupFunction", (n: Int) => { n > 10 })
- ctx.udf.register("havingFilter", (n: Long) => { n > 2000 })
- ctx.udf.register("whereFilter", (n: Int) => { n < 150 })
- ctx.udf.register("timesHundred", (n: Long) => { n * 100 })
+ sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
+ sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 })
+ sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 })
+ sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 })
val df = Seq(("red", 1), ("red", 2), ("blue", 10),
("green", 100), ("green", 200)).toDF("g", "v")
@@ -172,7 +172,7 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("struct UDF") {
- ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+ sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result =
sql("SELECT returnStruct('test', 'test2') as ret")
@@ -181,13 +181,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("udf that is transformed") {
- ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+ sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
test("type coercion for udf inputs") {
- ctx.udf.register("intExpected", (x: Int) => x)
+ sqlContext.udf.register("intExpected", (x: Int) => x)
// pass a decimal to intExpected.
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
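Every registration in this suite now goes through sqlContext.udf. The same pattern with a throwaway function, shown outside any particular test; the UDF name and body are made up for illustration:

  sqlContext.udf.register("plusOne", (n: Int) => n + 1)
  assert(sql("SELECT plusOne(41)").head().getInt(0) === 42)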
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index b6d279ae47..fa8f9c8e00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -90,7 +90,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
}
test("UDTs and UDFs") {
- ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
+ sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
pointsRDD.registerTempTable("points")
checkAnswer(
sql("SELECT testType(features) from points"),
@@ -148,8 +148,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
StructField("vec", new MyDenseVectorUDT, false)
))
- val stringRDD = ctx.sparkContext.parallelize(data)
- val jsonRDD = ctx.read.schema(schema).json(stringRDD)
+ val stringRDD = sparkContext.parallelize(data)
+ val jsonRDD = sqlContext.read.schema(schema).json(stringRDD)
checkAnswer(
jsonRDD,
Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
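The last hunk is an instance of the general recipe the suite relies on: parallelize raw JSON strings and read them back with an explicit schema. A self-contained sketch of that recipe with made-up field names:

  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

  val schema = StructType(Seq(
    StructField("id", IntegerType, nullable = true),
    StructField("name", StringType, nullable = true)))
  val stringRDD = sparkContext.parallelize("""{"id": 1, "name": "a"}""" :: Nil)
  val jsonDF = sqlContext.read.schema(schema).json(stringRDD)
  // jsonDF.schema is exactly `schema`; jsonDF.collect() yields Row(1, "a")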
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 83db9b6510..cd3644eb9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
setupTestData()
test("simple columnar query") {
- val plan = ctx.executePlan(testData.logicalPlan).executedPlan
+ val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -39,16 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
- ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
.toDF().registerTempTable("sizeTst")
- ctx.cacheTable("sizeTst")
+ sqlContext.cacheTable("sizeTst")
assert(
- ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
- ctx.conf.autoBroadcastJoinThreshold)
+ sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
+ sqlContext.conf.autoBroadcastJoinThreshold)
}
test("projection") {
- val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
+ val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().map {
@@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
- val plan = ctx.executePlan(testData.logicalPlan).executedPlan
+ val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -69,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
- ctx.cacheTable("repeatedData")
+ sqlContext.cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
@@ -81,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
- ctx.cacheTable("nullableRepeatedData")
+ sqlContext.cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
@@ -96,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT time FROM timestamps"),
timestamps.collect().toSeq)
- ctx.cacheTable("timestamps")
+ sqlContext.cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
@@ -108,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
- ctx.cacheTable("withEmptyParts")
+ sqlContext.cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
@@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// Create a RDD for the schema
val rdd =
- ctx.sparkContext.parallelize((1 to 100), 10).map { i =>
+ sparkContext.parallelize((1 to 100), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
@@ -177,24 +177,24 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
- ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
+ sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
sql("cache table InMemoryCache_different_data_types")
// Make sure the table is indeed cached.
- val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan
+ sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan
assert(
- ctx.isCached("InMemoryCache_different_data_types"),
+ sqlContext.isCached("InMemoryCache_different_data_types"),
"InMemoryCache_different_data_types should be cached.")
// Issue a query and check the results.
checkAnswer(
sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
- ctx.table("InMemoryCache_different_data_types").collect())
- ctx.dropTempTable("InMemoryCache_different_data_types")
+ sqlContext.table("InMemoryCache_different_data_types").collect())
+ sqlContext.dropTempTable("InMemoryCache_different_data_types")
}
test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") {
- val df =
- ctx.range(1, 100).selectExpr("id % 10 as id").rdd.map(id => Tuple1(s"str_$id")).toDF("i")
+ val df = sqlContext.range(1, 100).selectExpr("id % 10 as id")
+ .rdd.map(id => Tuple1(s"str_$id")).toDF("i")
val cached = df.cache()
// count triggers the caching action. It should not throw.
cached.count()
@@ -205,7 +205,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// Check result.
checkAnswer(
cached,
- ctx.range(1, 100).selectExpr("id % 10 as id").rdd.map(id => Tuple1(s"str_$id")).toDF("i")
+ sqlContext.range(1, 100).selectExpr("id % 10 as id")
+ .rdd.map(id => Tuple1(s"str_$id")).toDF("i")
)
// Drop the cache.
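A compact version of the cache round trip these hunks rewrite, assuming a suite with testImplicits._ in scope; the table name is illustrative:

  sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
    .registerTempTable("cacheDemo")
  sqlContext.cacheTable("cacheDemo")
  assert(sqlContext.isCached("cacheDemo"), "cacheDemo should be cached")
  sqlContext.uncacheTable("cacheDemo")
  sqlContext.dropTempTable("cacheDemo")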
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index ab2644eb45..6b7401464f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -25,32 +25,32 @@ import org.apache.spark.sql.test.SQLTestData._
class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
- private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
- private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
+ private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize
+ private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
super.beforeAll()
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
- ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
+ sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
- val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key =>
+ val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
}, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
- ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
+ sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Enable in-memory table scan accumulators
- ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
- ctx.cacheTable("pruningData")
+ sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
+ sqlContext.cacheTable("pruningData")
}
override protected def afterAll(): Unit = {
try {
- ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
- ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
- ctx.uncacheTable("pruningData")
+ sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
+ sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+ sqlContext.uncacheTable("pruningData")
} finally {
super.afterAll()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 8998f51111..911d12e93e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.test.SharedSQLContext
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.localSeqToDataFrameHolder
+
test("shuffling UnsafeRows in exchange") {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(
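ExchangeSuite, like the SortSuite and TungstenSortSuite hunks below, compensates for the implicit conversion removed from SparkPlanTest later in this patch by importing it explicitly. What the import enables, in isolation:

  import testImplicits.localSeqToDataFrameHolder

  val input = (1 to 1000).map(Tuple1.apply)
  val df = input.toDF("a")   // Seq[Tuple1[Int]] -> DataFrame via the imported implicit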
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index fad93b014c..cafa1d5154 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution
-import org.apache.spark.SparkFunSuite
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.InternalRow
@@ -31,14 +30,14 @@ import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-class PlannerSuite extends SparkFunSuite with SharedSQLContext {
+class PlannerSuite extends SharedSQLContext {
import testImplicits._
setupTestData()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
- val _ctx = ctx
- import _ctx.planner._
+ val planner = sqlContext.planner
+ import planner._
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
val planned =
plannedOption.getOrElse(
@@ -53,8 +52,8 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext {
}
test("unions are collapsed") {
- val _ctx = ctx
- import _ctx.planner._
+ val planner = sqlContext.planner
+ import planner._
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
val logicalUnions = query collect { case u: logical.Union => u }
@@ -81,33 +80,30 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext {
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
- def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
- ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
- val fields = fieldTypes.zipWithIndex.map {
- case (dataType, index) => StructField(s"c${index}", dataType, true)
- } :+ StructField("key", IntegerType, true)
- val schema = StructType(fields)
- val row = Row.fromSeq(Seq.fill(fields.size)(null))
- val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
- ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
-
- val planned = sql(
- """
- |SELECT l.a, l.b
- |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
- """.stripMargin).queryExecution.executedPlan
-
- val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
- val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
-
- assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
- assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
-
- ctx.dropTempTable("testLimit")
+ def checkPlan(fieldTypes: Seq[DataType]): Unit = {
+ withTempTable("testLimit") {
+ val fields = fieldTypes.zipWithIndex.map {
+ case (dataType, index) => StructField(s"c${index}", dataType, true)
+ } :+ StructField("key", IntegerType, true)
+ val schema = StructType(fields)
+ val row = Row.fromSeq(Seq.fill(fields.size)(null))
+ val rowRDD = sparkContext.parallelize(row :: Nil)
+ sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
+
+ val planned = sql(
+ """
+ |SELECT l.a, l.b
+ |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
+ """.stripMargin).queryExecution.executedPlan
+
+ val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+ val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+ assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+ assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+ }
}
- val origThreshold = ctx.conf.autoBroadcastJoinThreshold
-
val simpleTypes =
NullType ::
BooleanType ::
@@ -124,7 +120,9 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext {
StringType ::
BinaryType :: Nil
- checkPlan(simpleTypes, newThreshold = 16434)
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") {
+ checkPlan(simpleTypes)
+ }
val complexTypes =
ArrayType(DoubleType, true) ::
@@ -136,36 +134,37 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext {
StructField("b", ArrayType(DoubleType), nullable = false),
StructField("c", DoubleType, nullable = false))) :: Nil
- checkPlan(complexTypes, newThreshold = 901617)
-
- ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "901617") {
+ checkPlan(complexTypes)
+ }
}
test("InMemoryRelation statistics propagation") {
- val origThreshold = ctx.conf.autoBroadcastJoinThreshold
- ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
-
- testData.limit(3).registerTempTable("tiny")
- sql("CACHE TABLE tiny")
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") {
+ withTempTable("tiny") {
+ testData.limit(3).registerTempTable("tiny")
+ sql("CACHE TABLE tiny")
- val a = testData.as("a")
- val b = ctx.table("tiny").as("b")
- val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
+ val a = testData.as("a")
+ val b = sqlContext.table("tiny").as("b")
+ val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
- val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
- val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+ val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+ val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
- assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
- assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+ assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+ assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
- ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+ sqlContext.clearCache()
+ }
+ }
}
test("efficient limit -> project -> sort") {
{
val query =
testData.select('key, 'value).sort('key).limit(2).logicalPlan
- val planned = ctx.planner.TakeOrderedAndProject(query)
+ val planned = sqlContext.planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
assert(planned.head.output === testData.select('key, 'value).logicalPlan.output)
}
@@ -175,7 +174,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext {
// into it.
val query =
testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan
- val planned = ctx.planner.TakeOrderedAndProject(query)
+ val planned = sqlContext.planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
assert(planned.head.output === testData.select('value, 'key).logicalPlan.output)
}
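The PlannerSuite rewrites trade manual conf save/restore and explicit dropTempTable calls for the withSQLConf and withTempTable helpers, which undo their changes when the block exits. A self-contained sketch of that nesting; the table name and schema are placeholders:

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

  withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") {
    withTempTable("demoLimit") {
      val schema = StructType(StructField("key", IntegerType, nullable = true) :: Nil)
      val rowRDD = sparkContext.parallelize(Row(1) :: Nil)
      sqlContext.createDataFrame(rowRDD, schema).registerTempTable("demoLimit")
      val planned = sql("SELECT key FROM demoLimit LIMIT 1").queryExecution.executedPlan
      // inspect `planned` here; the conf override and the temp table are both
      // cleaned up when the enclosing blocks finish
    }
  }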
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index ef6ad59b71..4492e37ad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
test("planner should insert unsafe->safe conversions when required") {
val plan = Limit(10, outputsUnsafe)
- val preparedPlan = ctx.prepareForExecution.execute(plan)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
}
test("filter can process unsafe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
- val preparedPlan = ctx.prepareForExecution.execute(plan)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).size === 1)
assert(preparedPlan.outputsUnsafeRows)
}
test("filter can process safe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
- val preparedPlan = ctx.prepareForExecution.execute(plan)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
}
@@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
test("union requires all of its input rows' formats to agree") {
val plan = Union(Seq(outputsSafe, outputsUnsafe))
assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
- val preparedPlan = ctx.prepareForExecution.execute(plan)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("union can process safe rows") {
val plan = Union(Seq(outputsSafe, outputsSafe))
- val preparedPlan = ctx.prepareForExecution.execute(plan)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(!preparedPlan.outputsUnsafeRows)
}
test("union can process unsafe rows") {
val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
- val preparedPlan = ctx.prepareForExecution.execute(plan)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("round trip with ConvertToUnsafe and ConvertToSafe") {
val input = Seq(("hello", 1), ("world", 2))
checkAnswer(
- ctx.createDataFrame(input),
+ sqlContext.createDataFrame(input),
plan => ConvertToSafe(ConvertToUnsafe(plan)),
input.map(Row.fromTuple)
)
}
test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
- SparkPlan.currentContext.set(ctx)
+ SparkPlan.currentContext.set(sqlContext)
val schema = ArrayType(StringType)
val rows = (1 to 100).map { i =>
InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index 8fa77b0fcb..3073d492e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext
class SortSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.localSeqToDataFrameHolder
// This test was originally added as an example of how to use [[SparkPlanTest]];
// it's not designed to be a comprehensive test of ExternalSort.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 5ab8f44fae..de45ae4635 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -31,14 +31,7 @@ import org.apache.spark.sql.test.SQLTestUtils
* class's test helper methods can be used, see [[SortSuite]].
*/
private[sql] abstract class SparkPlanTest extends SparkFunSuite {
- protected def _sqlContext: SQLContext
-
- /**
- * Creates a DataFrame from a local Seq of Product.
- */
- implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
- _sqlContext.implicits.localSeqToDataFrameHolder(data)
- }
+ protected def sqlContext: SQLContext
/**
* Runs the plan and makes sure the answer matches the expected result.
@@ -98,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -122,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
- input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
+ input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -149,13 +142,13 @@ object SparkPlanTest {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
- _sqlContext: SQLContext): Option[String] = {
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
- executePlan(expectedOutputPlan, _sqlContext)
+ executePlan(expectedOutputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -170,7 +163,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
- executePlan(outputPlan, _sqlContext)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -210,12 +203,12 @@ object SparkPlanTest {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
- _sqlContext: SQLContext): Option[String] = {
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
- executePlan(outputPlan, _sqlContext)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -238,10 +231,10 @@ object SparkPlanTest {
}
}
- private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
+ private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = _sqlContext.prepareForExecution.execute(
+ val resolvedPlan = sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
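With SparkPlanTest reduced to an abstract protected sqlContext, a concrete suite only needs to mix in something that supplies it and re-import the Seq conversion it wants. A sketch of the resulting contract, assuming the suite sits in the same org.apache.spark.sql test package; the suite name is illustrative:

  class MyPlanSuite extends SparkPlanTest with SharedSQLContext {
    // SharedSQLContext supplies the protected sqlContext that SparkPlanTest requires
    import testImplicits.localSeqToDataFrameHolder
    // tests can now build inputs with Seq((1, "a")).toDF("i", "s")
  }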
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 3158458edb..7a0f0dfd2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -29,15 +29,16 @@ import org.apache.spark.sql.types._
* A test suite that generates randomized data to test the [[TungstenSort]] operator.
*/
class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.localSeqToDataFrameHolder
override def beforeAll(): Unit = {
super.beforeAll()
- ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+ sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
}
override def afterAll(): Unit = {
try {
- ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED)
} finally {
super.afterAll()
}
@@ -64,8 +65,7 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
}
test("sorting updates peak execution memory") {
- val sc = ctx.sparkContext
- AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
(child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child),
@@ -83,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
) {
test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
val inputData = Seq.fill(1000)(randomDataGenerator())
- val inputDf = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+ val inputDf = sqlContext.createDataFrame(
+ sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
assert(TungstenSort.supportsSchema(inputDf.schema))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 5fdb82b067..afda0d29f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -37,7 +37,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
() => new InterpretedMutableProjection(expr, schema)
}
- val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy")
+ val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 1174b27732..6a18cc6d27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -215,7 +215,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Complex field and type inferring with null in sampling") {
- val jsonDF = ctx.read.json(jsonNullStruct)
+ val jsonDF = sqlContext.read.json(jsonNullStruct)
val expectedSchema = StructType(
StructField("headers", StructType(
StructField("Charset", StringType, true) ::
@@ -234,7 +234,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Primitive field and type inferring") {
- val jsonDF = ctx.read.json(primitiveFieldAndType)
+ val jsonDF = sqlContext.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType(20, 0), true) ::
@@ -262,7 +262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Complex field and type inferring") {
- val jsonDF = ctx.read.json(complexFieldAndType1)
+ val jsonDF = sqlContext.read.json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
@@ -361,7 +361,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("GetField operation on complex data type") {
- val jsonDF = ctx.read.json(complexFieldAndType1)
+ val jsonDF = sqlContext.read.json(complexFieldAndType1)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -377,7 +377,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Type conflict in primitive field values") {
- val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
@@ -449,7 +449,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
ignore("Type conflict in primitive field values (Ignored)") {
- val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
+ val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
// Right now, the analyzer does not promote strings in a boolean expression.
@@ -502,7 +502,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Type conflict in complex field values") {
- val jsonDF = ctx.read.json(complexFieldValueTypeConflict)
+ val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("array", ArrayType(LongType, true), true) ::
@@ -526,7 +526,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Type conflict in array elements") {
- val jsonDF = ctx.read.json(arrayElementTypeConflict)
+ val jsonDF = sqlContext.read.json(arrayElementTypeConflict)
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
@@ -554,7 +554,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("Handling missing fields") {
- val jsonDF = ctx.read.json(missingFields)
+ val jsonDF = sqlContext.read.json(missingFields)
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
@@ -573,9 +573,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val dir = Utils.createTempDir()
dir.delete()
val path = dir.getCanonicalFile.toURI.toString
- ctx.sparkContext.parallelize(1 to 100)
+ sparkContext.parallelize(1 to 100)
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
- val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path)
+ val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path)
val analyzed = jsonDF.queryExecution.analyzed
assert(
@@ -590,7 +590,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
- ctx.read.schema(schema).json(path)
+ sqlContext.read.schema(schema).json(path)
.queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.paths === Array(path))
@@ -603,7 +603,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
- val jsonDF = ctx.read.json(path)
+ val jsonDF = sqlContext.read.json(path)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType(20, 0), true) ::
@@ -672,7 +672,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
- val jsonDF1 = ctx.read.schema(schema).json(path)
+ val jsonDF1 = sqlContext.read.schema(schema).json(path)
assert(schema === jsonDF1.schema)
@@ -689,7 +689,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
"this is a simple string.")
)
- val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType)
+ val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType)
assert(schema === jsonDF2.schema)
@@ -710,7 +710,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("Applying schemas with MapType") {
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
- val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
+ val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1)
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
@@ -738,7 +738,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val schemaWithComplexMap = StructType(
StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
- val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2)
+ val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2)
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
@@ -764,7 +764,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-2096 Correctly parse dot notations") {
- val jsonDF = ctx.read.json(complexFieldAndType2)
+ val jsonDF = sqlContext.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -782,7 +782,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-3390 Complex arrays") {
- val jsonDF = ctx.read.json(complexFieldAndType2)
+ val jsonDF = sqlContext.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -805,7 +805,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-3308 Read top level JSON arrays") {
- val jsonDF = ctx.read.json(jsonArray)
+ val jsonDF = sqlContext.read.json(jsonArray)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
@@ -823,64 +823,63 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("Corrupt records") {
// Test if we can query corrupt records.
- val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
- ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
-
- val jsonDF = ctx.read.json(corruptRecords)
- jsonDF.registerTempTable("jsonTable")
-
- val schema = StructType(
- StructField("_unparsed", StringType, true) ::
- StructField("a", StringType, true) ::
- StructField("b", StringType, true) ::
- StructField("c", StringType, true) :: Nil)
-
- assert(schema === jsonDF.schema)
-
- // In HiveContext, backticks should be used to access columns starting with a underscore.
- checkAnswer(
- sql(
- """
- |SELECT a, b, c, _unparsed
- |FROM jsonTable
- """.stripMargin),
- Row(null, null, null, "{") ::
- Row(null, null, null, "") ::
- Row(null, null, null, """{"a":1, b:2}""") ::
- Row(null, null, null, """{"a":{, b:3}""") ::
- Row("str_a_4", "str_b_4", "str_c_4", null) ::
- Row(null, null, null, "]") :: Nil
- )
-
- checkAnswer(
- sql(
- """
- |SELECT a, b, c
- |FROM jsonTable
- |WHERE _unparsed IS NULL
- """.stripMargin),
- Row("str_a_4", "str_b_4", "str_c_4")
- )
-
- checkAnswer(
- sql(
- """
- |SELECT _unparsed
- |FROM jsonTable
- |WHERE _unparsed IS NOT NULL
- """.stripMargin),
- Row("{") ::
- Row("") ::
- Row("""{"a":1, b:2}""") ::
- Row("""{"a":{, b:3}""") ::
- Row("]") :: Nil
- )
-
- ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") {
+ withTempTable("jsonTable") {
+ val jsonDF = sqlContext.read.json(corruptRecords)
+ jsonDF.registerTempTable("jsonTable")
+
+ val schema = StructType(
+ StructField("_unparsed", StringType, true) ::
+ StructField("a", StringType, true) ::
+ StructField("b", StringType, true) ::
+ StructField("c", StringType, true) :: Nil)
+
+ assert(schema === jsonDF.schema)
+
+          // In HiveContext, backticks should be used to access columns starting with an underscore.
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, b, c, _unparsed
+ |FROM jsonTable
+ """.stripMargin),
+ Row(null, null, null, "{") ::
+ Row(null, null, null, "") ::
+ Row(null, null, null, """{"a":1, b:2}""") ::
+ Row(null, null, null, """{"a":{, b:3}""") ::
+ Row("str_a_4", "str_b_4", "str_c_4", null) ::
+ Row(null, null, null, "]") :: Nil
+ )
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, b, c
+ |FROM jsonTable
+ |WHERE _unparsed IS NULL
+ """.stripMargin),
+ Row("str_a_4", "str_b_4", "str_c_4")
+ )
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT _unparsed
+ |FROM jsonTable
+ |WHERE _unparsed IS NOT NULL
+ """.stripMargin),
+ Row("{") ::
+ Row("") ::
+ Row("""{"a":1, b:2}""") ::
+ Row("""{"a":{, b:3}""") ::
+ Row("]") :: Nil
+ )
+ }
+ }
}
test("SPARK-4068: nulls in arrays") {
- val jsonDF = ctx.read.json(nullsInArrays)
+ val jsonDF = sqlContext.read.json(nullsInArrays)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
@@ -926,7 +925,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = ctx.createDataFrame(rowRDD1, schema1)
+ val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDF
val result = df2.toJSON.collect()
@@ -949,7 +948,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = ctx.createDataFrame(rowRDD2, schema2)
+ val df3 = sqlContext.createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDF
val result2 = df4.toJSON.collect()
@@ -957,8 +956,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
- val jsonDF = ctx.read.json(primitiveFieldAndType)
- val primTable = ctx.read.json(jsonDF.toJSON)
+ val jsonDF = sqlContext.read.json(primitiveFieldAndType)
+ val primTable = sqlContext.read.json(jsonDF.toJSON)
primTable.registerTempTable("primativeTable")
checkAnswer(
sql("select * from primativeTable"),
@@ -970,8 +969,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
"this is a simple string.")
)
- val complexJsonDF = ctx.read.json(complexFieldAndType1)
- val compTable = ctx.read.json(complexJsonDF.toJSON)
+ val complexJsonDF = sqlContext.read.json(complexFieldAndType1)
+ val compTable = sqlContext.read.json(complexJsonDF.toJSON)
compTable.registerTempTable("complexTable")
// Access elements of a primitive array.
checkAnswer(
@@ -1039,25 +1038,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
Some(empty),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
- None, None)(ctx)
+ None, None)(sqlContext)
val logicalRelation0 = LogicalRelation(relation0)
val relation1 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
- None, None)(ctx)
+ None, None)(sqlContext)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 = new JSONRelation(
Some(singleRow),
0.5,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
- None, None)(ctx)
+ None, None)(sqlContext)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
- None, None)(ctx)
+ None, None)(sqlContext)
val logicalRelation3 = LogicalRelation(relation3)
assert(relation0 !== relation1)
@@ -1078,18 +1077,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
withTempPath(dir => {
val path = dir.getCanonicalFile.toURI.toString
- ctx.sparkContext.parallelize(1 to 100)
+ sparkContext.parallelize(1 to 100)
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val d1 = ResolvedDataSource(
- ctx,
+ sqlContext,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
options = Map("path" -> path))
val d2 = ResolvedDataSource(
- ctx,
+ sqlContext,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
@@ -1105,24 +1104,21 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-7565 MapType in JsonRDD") {
- val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
- ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
-
- val schemaWithSimpleMap = StructType(
- StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
- try {
- val temp = Utils.createTempDir().getPath
-
- val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
- df.write.mode("overwrite").parquet(temp)
- // order of MapType is not defined
- assert(ctx.read.parquet(temp).count() == 5)
-
- val df2 = ctx.read.json(corruptRecords)
- df2.write.mode("overwrite").parquet(temp)
- checkAnswer(ctx.read.parquet(temp), df2.collect())
- } finally {
- ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+ withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") {
+ withTempDir { dir =>
+ val schemaWithSimpleMap = StructType(
+ StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
+ val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1)
+
+ val path = dir.getAbsolutePath
+ df.write.mode("overwrite").parquet(path)
+ // order of MapType is not defined
+ assert(sqlContext.read.parquet(path).count() == 5)
+
+ val df2 = sqlContext.read.json(corruptRecords)
+ df2.write.mode("overwrite").parquet(path)
+ checkAnswer(sqlContext.read.parquet(path), df2.collect())
+ }
}
}
@@ -1142,19 +1138,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val d1 = new File(root, "d1=1")
// root/dt=1/col1=abc
val p1_col1 = makePartition(
- ctx.sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""),
+ sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""),
d1,
"col1",
"abc")
// root/dt=1/col1=abd
val p2 = makePartition(
- ctx.sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""),
+ sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""),
d1,
"col1",
"abd")
- ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
+ sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
checkAnswer(sql(
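Both JSON tests that need a corrupt-record column now scope the setting with withSQLConf instead of saving and restoring it by hand. A minimal sketch of reading malformed input under that setting; the input strings here are placeholders rather than the suite's corruptRecords fixture:

  withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") {
    val records = sparkContext.parallelize("""{"a": 1}""" :: """{"a":""" :: Nil)
    val df = sqlContext.read.json(records)
    // the well-formed line fills `a`; the broken line lands in `_unparsed`
    df.select("a", "_unparsed").collect()
  }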
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 2864181cf9..713d1da1cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
private[json] trait TestJsonData {
- protected def _sqlContext: SQLContext
+ protected def sqlContext: SQLContext
def primitiveFieldAndType: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@@ -35,7 +35,7 @@ private[json] trait TestJsonData {
}""" :: Nil)
def primitiveFieldValueTypeConflict: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -46,14 +46,14 @@ private[json] trait TestJsonData {
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
def jsonNullStruct: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
def complexFieldValueTypeConflict: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@@ -64,14 +64,14 @@ private[json] trait TestJsonData {
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
def arrayElementTypeConflict: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
def missingFields: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
@@ -79,7 +79,7 @@ private[json] trait TestJsonData {
"""{"e":"str"}""" :: Nil)
def complexFieldAndType1: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@@ -95,7 +95,7 @@ private[json] trait TestJsonData {
}""" :: Nil)
def complexFieldAndType2: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@@ -149,7 +149,7 @@ private[json] trait TestJsonData {
}""" :: Nil)
def mapType1: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
@@ -157,7 +157,7 @@ private[json] trait TestJsonData {
"""{"map": {"e": null}}""" :: Nil)
def mapType2: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -166,21 +166,21 @@ private[json] trait TestJsonData {
"""{"map": {"f": {"field1": null}}}""" :: Nil)
def nullsInArrays: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
def jsonArray: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
def corruptRecords: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@@ -189,7 +189,7 @@ private[json] trait TestJsonData {
"""]""" :: Nil)
def emptyRecords: RDD[String] =
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
@@ -198,7 +198,7 @@ private[json] trait TestJsonData {
"""]""" :: Nil)
- lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
+ lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
- def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
+ def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]())
}
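For orientation, a minimal sketch (illustrative only, not part of the patch) of how a suite consumes TestJsonData after the rename: any suite mixing in SharedSQLContext already supplies the `sqlContext` the trait requires, so no `_sqlContext` plumbing is needed. The suite and test names below are hypothetical.

package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite: SharedSQLContext provides `sqlContext`, which satisfies
// the TestJsonData requirement directly.
class ExampleJsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
  test("primitive JSON fixture parses") {
    // The Spark 1.5-era DataFrameReader accepts an RDD[String] of JSON records.
    val df = sqlContext.read.json(primitiveFieldAndType)
    assert(df.schema.fieldNames.contains("integer"))
  }
}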
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index 91f3ce4d34..0835bd1230 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -39,12 +39,13 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
val fsPath = new Path(path)
- val fs = fsPath.getFileSystem(configuration)
+ val fs = fsPath.getFileSystem(hadoopConfiguration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
override def accept(path: Path): Boolean = pathFilter(path)
}).toSeq.asJava
- val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)
+ val footers =
+ ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true)
footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 08d2b9dee9..cd552e8337 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -101,7 +101,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("fixed-length decimals") {
def makeDecimalRDD(decimal: DecimalType): DataFrame =
- sqlContext.sparkContext
+ sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
.toDF()
@@ -119,7 +119,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("date type") {
def makeDateRDD(): DataFrame =
- sqlContext.sparkContext
+ sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(DateTimeUtils.toJavaDate(i)))
.toDF()
@@ -207,7 +207,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("compression codec") {
def compressionCodecFor(path: String): String = {
val codecs = ParquetTypesConverter
- .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala
+ .readMetaData(new Path(path), Some(hadoopConfiguration)).getBlocks.asScala
.flatMap(_.getColumns.asScala)
.map(_.getCodec.name())
.distinct
@@ -277,14 +277,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("write metadata") {
withTempPath { file =>
val path = new Path(file.toURI.toString)
- val fs = FileSystem.getLocal(configuration)
+ val fs = FileSystem.getLocal(hadoopConfiguration)
val attributes = ScalaReflection.attributesFor[(Int, String)]
- ParquetTypesConverter.writeMetaData(attributes, path, configuration)
+ ParquetTypesConverter.writeMetaData(attributes, path, hadoopConfiguration)
assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)))
assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)))
- val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration))
+ val metaData = ParquetTypesConverter.readMetaData(path, Some(hadoopConfiguration))
val actualSchema = metaData.getFileMetaData.getSchema
val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes)
@@ -355,7 +355,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
val path = new Path(location.getCanonicalPath)
ParquetFileWriter.writeMetadataFile(
- sqlContext.sparkContext.hadoopConfiguration,
+ sparkContext.hadoopConfiguration,
path,
Collections.singletonList(
new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList()))))
@@ -370,12 +370,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
test("SPARK-6352 DirectParquetOutputCommitter") {
- val clonedConf = new Configuration(configuration)
+ val clonedConf = new Configuration(hadoopConfiguration)
// Write to a parquet file and let it fail.
// _temporary should be missing if direct output committer works.
try {
- configuration.set("spark.sql.parquet.output.committer.class",
+ hadoopConfiguration.set("spark.sql.parquet.output.committer.class",
classOf[DirectParquetOutputCommitter].getCanonicalName)
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
@@ -383,23 +383,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
- val fs = path.getFileSystem(configuration)
+ val fs = path.getFileSystem(hadoopConfiguration)
assert(!fs.exists(path))
}
} finally {
// Hadoop 1 doesn't have `Configuration.unset`
- configuration.clear()
- clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
}
}
test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") {
- val clonedConf = new Configuration(configuration)
+ val clonedConf = new Configuration(hadoopConfiguration)
// Write to a parquet file and let it fail.
// _temporary should be missing if direct output committer works.
try {
- configuration.set("spark.sql.parquet.output.committer.class",
+ hadoopConfiguration.set("spark.sql.parquet.output.committer.class",
"org.apache.spark.sql.parquet.DirectParquetOutputCommitter")
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
@@ -407,25 +407,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
- val fs = path.getFileSystem(configuration)
+ val fs = path.getFileSystem(hadoopConfiguration)
assert(!fs.exists(path))
}
} finally {
// Hadoop 1 doesn't have `Configuration.unset`
- configuration.clear()
- clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
}
}
test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") {
withTempPath { dir =>
- val clonedConf = new Configuration(configuration)
+ val clonedConf = new Configuration(hadoopConfiguration)
- configuration.set(
+ hadoopConfiguration.set(
SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName)
- configuration.set(
+ hadoopConfiguration.set(
"spark.sql.parquet.output.committer.class",
classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName)
@@ -436,8 +436,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
assert(message === "Intentional exception for testing purposes")
} finally {
// Hadoop 1 doesn't have `Configuration.unset`
- configuration.clear()
- clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
}
}
}
@@ -455,11 +455,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
test("SPARK-7837 Do not close output writer twice when commitTask() fails") {
- val clonedConf = new Configuration(configuration)
+ val clonedConf = new Configuration(hadoopConfiguration)
// Using a output committer that always fail when committing a task, so that both
// `commitTask()` and `abortTask()` are invoked.
- configuration.set(
+ hadoopConfiguration.set(
"spark.sql.parquet.output.committer.class",
classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName)
@@ -483,8 +483,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
} finally {
// Hadoop 1 doesn't have `Configuration.unset`
- configuration.clear()
- clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
}
}
}
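The clone-and-restore choreography around `hadoopConfiguration` in the tests above exists because Hadoop 1's Configuration has no `unset`, so the whole configuration is snapshotted and replayed afterwards. A minimal sketch of that pattern in isolation (the helper name is illustrative, not from this commit):

import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration

// Illustrative helper: apply temporary Hadoop settings, run `f`, then restore
// the original entries, since Hadoop 1 offers no Configuration.unset.
def withTempHadoopConf(conf: Configuration)(pairs: (String, String)*)(f: => Unit): Unit = {
  val snapshot = new Configuration(conf)
  try {
    pairs.foreach { case (key, value) => conf.set(key, value) }
    f
  } finally {
    conf.clear()
    snapshot.asScala.foreach(entry => conf.set(entry.getKey, entry.getValue))
  }
}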
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index ed8bafb10c..7bac8609e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -517,7 +517,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
}
val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
- val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
withTempPath { dir =>
df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index a379523d67..9edbb52268 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.util.Utils
* A test suite that tests various Parquet queries.
*/
class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
+ import testImplicits._
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -40,22 +41,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
- ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
- checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
+ checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
- ctx.catalog.unregisterTable(Seq("tmp"))
+ sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
- ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
- checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
+ checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
}
- ctx.catalog.unregisterTable(Seq("tmp"))
+ sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
@@ -118,9 +119,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
- val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
- val df2 = ctx.read.parquet(file.getCanonicalPath)
+ val df2 = sqlContext.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
}
}
@@ -129,12 +130,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
- ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
// delete summary files, so if we don't merge part-files, one column will not be included.
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
- assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
+ assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@@ -153,9 +154,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
- ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
- assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
+ sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@@ -171,19 +172,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
withTempPath { dir =>
val basePath = dir.getCanonicalPath
- ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
+ sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
// Disables the global SQL option for schema merging
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
assertResult(2) {
// Disables schema merging via data source option
- ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
+ sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
}
assertResult(3) {
// Enables schema merging via data source option
- ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
+ sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
}
}
}
@@ -193,7 +194,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
withTempPath { dir =>
val basePath = dir.getCanonicalPath
val schema = StructType(Array(StructField("name", DecimalType(10, 5), false)))
- val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45"))))
+ val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45"))))
val df = sqlContext.createDataFrame(rowRDD, schema)
df.write.parquet(basePath)
@@ -203,9 +204,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
}
test("SPARK-10005 Schema merging for nested struct") {
- val sqlContext = _sqlContext
- import sqlContext.implicits._
-
withTempPath { dir =>
val path = dir.getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 5dbc7d1630..442fafb12f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
private[sql] trait ParquetTest extends SQLTestUtils {
- protected def _sqlContext: SQLContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
@@ -43,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
- _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
+ sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -55,7 +54,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
- withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
+ withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
}
/**
@@ -67,14 +66,14 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetDataFrame(data) { df =>
- _sqlContext.registerDataFrameAsTable(df, tableName)
+ sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
- _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+ sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 53a0e53fd7..dcbfdca71a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -33,8 +33,7 @@ import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
* without serializing the hashed relation, which does not happen in local mode.
*/
class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
- private var sc: SparkContext = null
- private var sqlContext: SQLContext = null
+ protected var sqlContext: SQLContext = null
/**
* Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
@@ -44,15 +43,14 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
val conf = new SparkConf()
.setMaster("local-cluster[2,1,1024]")
.setAppName("testing")
- sc = new SparkContext(conf)
+ val sc = new SparkContext(conf)
sqlContext = new SQLContext(sc)
sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
}
override def afterAll(): Unit = {
- sc.stop()
- sc = null
+ sqlContext.sparkContext.stop()
sqlContext = null
}
@@ -60,7 +58,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
* Test whether the specified broadcast join updates the peak execution memory accumulator.
*/
private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
- AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) {
val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
// Comparison at the end is for broadcast left semi join
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 4c9187a9a7..e5fd9e277f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
test("GeneralHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
- val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
+ val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
@@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
- val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
+ val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
@@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
- val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
+ val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data")
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy()).toArray
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index cc649b9bd4..4174ee0550 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -27,9 +27,10 @@ import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.localSeqToDataFrameHolder
- private lazy val myUpperCaseData = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Seq(
+ private lazy val myUpperCaseData = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(
Row(1, "A"),
Row(2, "B"),
Row(3, "C"),
@@ -39,8 +40,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
Row(null, "G")
)), new StructType().add("N", IntegerType).add("L", StringType))
- private lazy val myLowerCaseData = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Seq(
+ private lazy val myLowerCaseData = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(
Row(1, "a"),
Row(2, "b"),
Row(3, "c"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index a1a617d7b7..c2e0bdac17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
- private lazy val left = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Seq(
+ private lazy val left = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(2, 100.0),
Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
@@ -40,8 +40,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
Row(null, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
- private lazy val right = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Seq(
+ private lazy val right = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(
Row(0, 0.0),
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, -1.0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index baa86e320d..3afd762942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
- private lazy val left = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Seq(
+ private lazy val left = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(1, 2.0),
Row(2, 1.0),
@@ -40,8 +40,8 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
Row(6, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
- private lazy val right = ctx.createDataFrame(
- ctx.sparkContext.parallelize(Seq(
+ private lazy val right = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(
Row(2, 3.0),
Row(2, 3.0),
Row(3, 2.0),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 80006bf077..6afffae161 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -36,7 +36,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("LongSQLMetric should not box Long") {
- val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long")
+ val l = SQLMetrics.createLongMetric(sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@@ -50,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("Normal accumulator should do boxing") {
// We need this test to make sure BoxingFinder works.
- val l = ctx.sparkContext.accumulator(0L)
+ val l = sparkContext.accumulator(0L)
val f = () => { l += 1L }
BoxingFinder.getClassReader(f.getClass).foreach { cl =>
val boxingFinder = new BoxingFinder()
@@ -71,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
df: DataFrame,
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
- val previousExecutionIds = ctx.listener.executionIdToData.keySet
+ val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
df.collect()
- ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
- val jobs = ctx.listener.getExecution(executionId).get.jobs
+ val jobs = sqlContext.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= expectedNumOfJobs)
if (jobs.size == expectedNumOfJobs) {
// If we can track all jobs, check the metric values
- val metricValues = ctx.listener.getExecutionMetrics(executionId)
+ val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node =>
expectedMetrics.contains(node.id)
}.map { node =>
@@ -474,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("save metrics") {
withTempPath { file =>
- val previousExecutionIds = ctx.listener.executionIdToData.keySet
+ val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
// Assume the execution plan is
// PhysicalRDD(nodeId = 0)
person.select('name).write.format("json").save(file.getAbsolutePath)
- ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
- val jobs = ctx.listener.getExecution(executionId).get.jobs
+ val jobs = sqlContext.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= 1)
- val metricValues = ctx.listener.getExecutionMetrics(executionId)
+ val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we still can check the value.
assert(metricValues.values.toSeq === Seq(2L))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 80d1e88956..2bbb41ca77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("basic") {
- val listener = new SQLListener(ctx)
+ val listener = new SQLListener(sqlContext)
val executionId = 0
val df = createTestDataFrame
val accumulatorIds =
@@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
- val listener = new SQLListener(ctx)
+ val listener = new SQLListener(sqlContext)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
- val listener = new SQLListener(ctx)
+ val listener = new SQLListener(sqlContext)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("onExecutionEnd happens before onJobEnd(JobFailed)") {
- val listener = new SQLListener(ctx)
+ val listener = new SQLListener(sqlContext)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d8c9a08d84..ed710689cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -255,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
}
test("Basic API") {
- assert(ctx.read.jdbc(
+ assert(sqlContext.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}
test("Basic API with FetchSize") {
val properties = new Properties
properties.setProperty("fetchSize", "2")
- assert(ctx.read.jdbc(
+ assert(sqlContext.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}
test("Partitioning via JDBCPartitioningInfo API") {
assert(
- ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
+ sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
.collect().length === 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
- assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
+ assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
.collect().length === 3)
}
@@ -330,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
}
test("test DATE types") {
- val rows = ctx.read.jdbc(
+ val rows = sqlContext.read.jdbc(
urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
@@ -340,8 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
}
test("test DATE types in cache") {
- val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
+ sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().registerTempTable("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
@@ -349,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
}
test("test types for null value") {
- val rows = ctx.read.jdbc(
+ val rows = sqlContext.read.jdbc(
urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
@@ -396,7 +396,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
- val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 5dc3a2c07b..e23ee66931 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -22,13 +22,12 @@ import java.util.Properties
import org.scalatest.BeforeAndAfter
-import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
+class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
@@ -76,8 +75,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon
conn1.close()
}
- private lazy val sc = ctx.sparkContext
-
private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
private lazy val arr1x2 = Array[Row](Row.apply("fred", 3))
private lazy val schema2 = StructType(
@@ -91,49 +88,50 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
- assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
- assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+ assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+ assert(
+ 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}
test("CREATE with overwrite") {
- val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
- val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
+ val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.DROPTEST", properties)
- assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties)
- assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
- assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
+ assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
test("CREATE then INSERT to append") {
- val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url, "TEST.APPENDTEST", new Properties)
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties)
- assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
- assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
+ assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
+ assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
}
test("CREATE then INSERT to truncate") {
- val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
- assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
- assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
+ assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
test("Incompatible INSERT to append") {
- val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
- val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
intercept[org.apache.spark.SparkException] {
@@ -143,14 +141,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon
test("INSERT to JDBC Datasource") {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
- assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 9bc3f6bcf6..6fc9febe49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -26,10 +26,8 @@ import org.apache.spark.sql.execution.datasources.DDLException
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
-
class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
protected override lazy val sql = caseInsensitiveContext.sql _
- private lazy val sparkContext = caseInsensitiveContext.sparkContext
private var path: File = null
override def beforeAll(): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index d74d29fb0b..af04079ec8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql._
-
private[sql] abstract class DataSourceTest extends QueryTest {
- protected def _sqlContext: SQLContext
// We want to test some edge cases.
protected lazy val caseInsensitiveContext: SQLContext = {
- val ctx = new SQLContext(_sqlContext.sparkContext)
+ val ctx = new SQLContext(sqlContext.sparkContext)
ctx.setConf(SQLConf.CASE_SENSITIVE, false)
ctx
}
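What the trimmed-down DataSourceTest gives its subclasses is simply a second SQLContext built on the shared SparkContext with case-insensitive resolution. An illustrative equivalent using the public string-keyed setter (the key shown is the Spark 1.5-era name behind SQLConf.CASE_SENSITIVE; `sqlContext` is assumed to come from the enclosing suite):

import org.apache.spark.sql.SQLContext

// Illustrative only: derive a case-insensitive context from the suite's
// existing SparkContext instead of requiring a separate _sqlContext member.
lazy val caseInsensitiveCtx: SQLContext = {
  val ctx = new SQLContext(sqlContext.sparkContext)
  ctx.setConf("spark.sql.caseSensitive", "false")
  ctx
}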
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 084d83f6e9..5b70d258d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
- private lazy val sparkContext = caseInsensitiveContext.sparkContext
private var path: File = null
override def beforeAll(): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index 79b6e9b45c..c9791879ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
val path = Utils.createTempDir()
path.delete()
- val df = ctx.range(100).select($"id", lit(1).as("data"))
+ val df = sqlContext.range(100).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
- ctx.read.load(path.getCanonicalPath),
+ sqlContext.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
@@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
val path = Utils.createTempDir()
path.delete()
- val base = ctx.range(100)
+ val base = sqlContext.range(100)
val df = base.unionAll(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
- ctx.read.load(path.getCanonicalPath),
+ sqlContext.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index f18546b4c2..10d2613689 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -28,7 +28,6 @@ import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
protected override lazy val sql = caseInsensitiveContext.sql _
- private lazy val sparkContext = caseInsensitiveContext.sparkContext
private var originalDefaultSource: String = null
private var path: File = null
private var df: DataFrame = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 3fc02df954..520dea7f7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -24,11 +24,11 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
* A collection of sample data used in SQL tests.
*/
private[sql] trait SQLTestData { self =>
- protected def _sqlContext: SQLContext
+ protected def sqlContext: SQLContext
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self._sqlContext
+ protected override def _sqlContext: SQLContext = self.sqlContext
}
import internalImplicits._
@@ -37,21 +37,21 @@ private[sql] trait SQLTestData { self =>
// Note: all test data should be lazy because the SQLContext is not set up yet.
protected lazy val emptyTestData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
Seq.empty[Int].map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("emptyTestData")
df
}
protected lazy val testData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("testData")
df
}
protected lazy val testData2: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
@@ -63,7 +63,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val testData3: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
TestData3(1, None) ::
TestData3(2, Some(2)) :: Nil).toDF()
df.registerTempTable("testData3")
@@ -71,14 +71,14 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val negativeData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
df.registerTempTable("negativeData")
df
}
protected lazy val largeAndSmallInts: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
@@ -90,7 +90,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val decimalData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
@@ -102,7 +102,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val binaryData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
BinaryData("12".getBytes, 1) ::
BinaryData("22".getBytes, 5) ::
BinaryData("122".getBytes, 3) ::
@@ -113,7 +113,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val upperCaseData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
UpperCaseData(1, "A") ::
UpperCaseData(2, "B") ::
UpperCaseData(3, "C") ::
@@ -125,7 +125,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val lowerCaseData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
@@ -135,7 +135,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val arrayData: RDD[ArrayData] = {
- val rdd = _sqlContext.sparkContext.parallelize(
+ val rdd = sqlContext.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
rdd.toDF().registerTempTable("arrayData")
@@ -143,7 +143,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val mapData: RDD[MapData] = {
- val rdd = _sqlContext.sparkContext.parallelize(
+ val rdd = sqlContext.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
@@ -154,13 +154,13 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val repeatedData: RDD[StringData] = {
- val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
+ val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("repeatedData")
rdd
}
protected lazy val nullableRepeatedData: RDD[StringData] = {
- val rdd = _sqlContext.sparkContext.parallelize(
+ val rdd = sqlContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("nullableRepeatedData")
@@ -168,7 +168,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val nullInts: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
@@ -178,7 +178,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val allNulls: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
@@ -188,7 +188,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val nullStrings: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
NullStrings(3, null) :: Nil).toDF()
@@ -197,13 +197,13 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val tableName: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
+ val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
df.registerTempTable("tableName")
df
}
protected lazy val unparsedStrings: RDD[String] = {
- _sqlContext.sparkContext.parallelize(
+ sqlContext.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
@@ -212,13 +212,13 @@ private[sql] trait SQLTestData { self =>
// An RDD with 4 elements and 8 partitions
protected lazy val withEmptyParts: RDD[IntField] = {
- val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
+ val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
rdd.toDF().registerTempTable("withEmptyParts")
rdd
}
protected lazy val person: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
Person(1, "jim", 20) :: Nil).toDF()
df.registerTempTable("person")
@@ -226,7 +226,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val salary: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
Salary(1, 1000.0) :: Nil).toDF()
df.registerTempTable("salary")
@@ -234,7 +234,7 @@ private[sql] trait SQLTestData { self =>
}
protected lazy val complexData: DataFrame = {
- val df = _sqlContext.sparkContext.parallelize(
+ val df = sqlContext.sparkContext.parallelize(
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
Nil).toDF()
@@ -246,7 +246,7 @@ private[sql] trait SQLTestData { self =>
* Initialize all test data such that all temp tables are properly registered.
*/
def loadTestData(): Unit = {
- assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
+ assert(sqlContext != null, "attempted to initialize test data before SQLContext.")
emptyTestData
testData
testData2
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index dc08306ad9..9214569f18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.util.Utils
@@ -47,13 +47,13 @@ private[sql] trait SQLTestUtils
with BeforeAndAfterAll
with SQLTestData { self =>
- protected def _sqlContext: SQLContext
+ protected def sparkContext = sqlContext.sparkContext
// Whether to materialize all test data before the first test is run
private var loadTestDataBeforeTests = false
// Shorthand for running a query using our SQLContext
- protected lazy val sql = _sqlContext.sql _
+ protected lazy val sql = sqlContext.sql _
/**
* A helper object for importing SQL implicits.
@@ -63,7 +63,14 @@ private[sql] trait SQLTestUtils
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
- protected override def _sqlContext: SQLContext = self._sqlContext
+ protected override def _sqlContext: SQLContext = self.sqlContext
+
+ // This must live here to preserve binary compatibility with Spark < 1.5.
+ implicit class StringToColumn(val sc: StringContext) {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args: _*))
+ }
+ }
}
/**
@@ -84,8 +91,8 @@ private[sql] trait SQLTestUtils
/**
* The Hadoop configuration used by the active [[SQLContext]].
*/
- protected def configuration: Configuration = {
- _sqlContext.sparkContext.hadoopConfiguration
+ protected def hadoopConfiguration: Configuration = {
+ sparkContext.hadoopConfiguration
}
/**
@@ -96,12 +103,12 @@ private[sql] trait SQLTestUtils
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption)
- (keys, values).zipped.foreach(_sqlContext.conf.setConfString)
+ val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
+ (keys, values).zipped.foreach(sqlContext.conf.setConfString)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => _sqlContext.conf.setConfString(key, value)
- case (key, None) => _sqlContext.conf.unsetConf(key)
+ case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
+ case (key, None) => sqlContext.conf.unsetConf(key)
}
}
}
@@ -133,7 +140,7 @@ private[sql] trait SQLTestUtils
* Drops temporary table `tableName` after calling `f`.
*/
protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
- try f finally tableNames.foreach(_sqlContext.dropTempTable)
+ try f finally tableNames.foreach(sqlContext.dropTempTable)
}
/**
@@ -142,7 +149,7 @@ private[sql] trait SQLTestUtils
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
try f finally {
tableNames.foreach { name =>
- _sqlContext.sql(s"DROP TABLE IF EXISTS $name")
+ sqlContext.sql(s"DROP TABLE IF EXISTS $name")
}
}
}
@@ -155,12 +162,12 @@ private[sql] trait SQLTestUtils
val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
try {
- _sqlContext.sql(s"CREATE DATABASE $dbName")
+ sqlContext.sql(s"CREATE DATABASE $dbName")
} catch { case cause: Throwable =>
fail("Failed to create temporary database", cause)
}
- try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
+ try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
}
/**
@@ -168,8 +175,8 @@ private[sql] trait SQLTestUtils
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
- _sqlContext.sql(s"USE $db")
- try f finally _sqlContext.sql(s"USE default")
+ sqlContext.sql(s"USE $db")
+ try f finally sqlContext.sql(s"USE default")
}
/**
@@ -177,7 +184,7 @@ private[sql] trait SQLTestUtils
* way to construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
- DataFrame(_sqlContext, plan)
+ DataFrame(sqlContext, plan)
}
}
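
These helpers all follow the same loan pattern: tweak shared state (a conf key, a temp table, a database), run the test body, and restore the state in a finally block even if the body throws. A sketch of the withSQLConf variant against a plain SQLContext, using the public getConf/setConf calls as stand-ins for the internal conf accessors used above:

    import scala.util.Try
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object WithConfSketch {
      // Set `pairs` for the duration of `body`, then restore whatever was there before.
      def withSQLConf(sqlContext: SQLContext)(pairs: (String, String)*)(body: => Unit): Unit = {
        val (keys, values) = pairs.unzip
        val previous = keys.map(k => Try(sqlContext.getConf(k)).toOption)
        keys.zip(values).foreach { case (k, v) => sqlContext.setConf(k, v) }
        try body finally {
          keys.zip(previous).foreach {
            case (k, Some(v)) => sqlContext.setConf(k, v)
            case (_, None)    => // the real helper calls conf.unsetConf; no public equivalent here
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("conf-sketch"))
        val sqlContext = new SQLContext(sc)
        withSQLConf(sqlContext)("spark.sql.shuffle.partitions" -> "4") {
          assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "4")
        }
        sc.stop()
      }
    }
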
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index d23c6a0732..963d10eed6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.test
-import org.apache.spark.sql.{ColumnName, SQLContext}
+import org.apache.spark.sql.SQLContext
/**
@@ -36,9 +36,7 @@ trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
- protected def ctx: TestSQLContext = _ctx
- protected def sqlContext: TestSQLContext = _ctx
- protected override def _sqlContext: SQLContext = _ctx
+ protected def sqlContext: SQLContext = _ctx
/**
* Initialize the [[TestSQLContext]].
@@ -64,15 +62,4 @@ trait SharedSQLContext extends SQLTestUtils {
super.afterAll()
}
}
-
- /**
- * Converts $"col name" into an [[Column]].
- * @since 1.3.0
- */
- // This must be duplicated here to preserve binary compatibility with Spark < 1.5.
- implicit class StringToColumn(val sc: StringContext) {
- def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args: _*))
- }
- }
}
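
The $"col" interpolator deleted here is the same one re-added inside testImplicits earlier in this diff: an implicit class over StringContext whose $ method builds a ColumnName from the interpolated string. A minimal sketch of the mechanics in plain Scala, with a stand-in MyColumn type instead of Spark's ColumnName:

    // Stand-in for org.apache.spark.sql.ColumnName, to keep the sketch dependency-free.
    case class MyColumn(name: String)

    object InterpolatorSketch {
      // Enriches StringContext so that $"a.key" (or $"col_$i") produces a MyColumn,
      // mirroring what the relocated StringToColumn class does for ColumnName.
      implicit class StringToColumn(val sc: StringContext) {
        def $(args: Any*): MyColumn = MyColumn(sc.s(args: _*))
      }

      def main(args: Array[String]): Unit = {
        val i = 2
        println($"a.key")    // MyColumn(a.key)
        println($"col_$i")   // MyColumn(col_2)
      }
    }
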
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 92ef2f7d74..d99d191ebe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -47,6 +47,6 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel
}
private object testData extends SQLTestData {
- protected override def _sqlContext: SQLContext = self
+ protected override def sqlContext: SQLContext = self
}
}
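
The inner testData object reaches back to the enclosing TestSQLContext through the class's self alias (the `self =>` annotation on TestSQLContext). A generic sketch of that pattern with made-up names:

    trait Fixture {
      protected def sqlContextName: String
      def describe(): String = s"fixtures bound to $sqlContextName"
    }

    // `self =>` gives the class body a stable alias for the outer instance,
    // so nested objects can refer to it unambiguously.
    class OuterContext(val name: String) { self =>
      private object fixtures extends Fixture {
        protected override def sqlContextName: String = self.name
      }
      def describeFixtures(): String = fixtures.describe()
    }

    object SelfAliasSketch extends App {
      println(new OuterContext("TestSQLContext").describeFixtures())
      // prints: fixtures bound to TestSQLContext
    }
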
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 57fea5d8db..77f43f9270 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
-import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.{SQLContext, SQLConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.CacheTableCommand
@@ -51,6 +51,11 @@ object TestHive
// SPARK-8910
.set("spark.ui.enabled", "false")))
+trait TestHiveSingleton {
+ protected val sqlContext: SQLContext = TestHive
+ protected val hiveContext: TestHiveContext = TestHive
+}
+
/**
* A locally running test instance of Spark's Hive execution engine.
*
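
TestHiveSingleton is the pivot of this cleanup: it exposes the shared TestHive instance both as a generic sqlContext and, for Hive-only APIs, as hiveContext, so suites no longer define their own _sqlContext or ctx aliases. A hypothetical suite using it might look like the sketch below; it assumes Spark's SQL test classpath, since QueryTest and TestHiveSingleton are test-only classes:

    import org.apache.spark.sql.{QueryTest, Row}
    import org.apache.spark.sql.hive.test.TestHiveSingleton

    // Hypothetical example suite: general features go through sqlContext,
    // Hive-specific calls (reset(), analyze(), ...) go through hiveContext.
    class ExampleHiveSuite extends QueryTest with TestHiveSingleton {
      import hiveContext.implicits._   // toDF, $"col", and friends
      import hiveContext.sql           // so tests can simply call sql("...")

      test("round-trips a small DataFrame through a temp table") {
        val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
        df.registerTempTable("example")
        checkAnswer(sql("SELECT value FROM example WHERE key = 2"), Row("b"))
      }
    }
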
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 39d315aaea..9adb3780a2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -19,14 +19,14 @@ package org.apache.spark.sql.hive
import java.io.File
-import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest}
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode}
import org.apache.spark.storage.RDDBlockId
import org.apache.spark.util.Utils
-class CachedTableSuite extends QueryTest {
+class CachedTableSuite extends QueryTest with TestHiveSingleton {
+ import hiveContext._
def rddIdOf(tableName: String): Int = {
val executedPlan = table(tableName).queryExecution.executedPlan
@@ -95,18 +95,18 @@ class CachedTableSuite extends QueryTest {
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
- TestHive.uncacheTable("src")
+ hiveContext.uncacheTable("src")
}
}
test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") {
- TestHive.sql("CACHE TABLE src")
+ sql("CACHE TABLE src")
assertCached(table("src"))
- assert(TestHive.isCached("src"), "Table 'src' should be cached")
+ assert(hiveContext.isCached("src"), "Table 'src' should be cached")
- TestHive.sql("UNCACHE TABLE src")
+ sql("UNCACHE TABLE src")
assertCached(table("src"), 0)
- assert(!TestHive.isCached("src"), "Table 'src' should not be cached")
+ assert(!hiveContext.isCached("src"), "Table 'src' should not be cached")
}
test("CACHE TABLE tableName AS SELECT * FROM anotherTable") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index 30f5313d2b..cf73783693 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -22,12 +22,12 @@ import scala.util.Try
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.catalyst.util.quietly
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.{AnalysisException, QueryTest}
-class ErrorPositionSuite extends QueryTest with BeforeAndAfter {
+class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter {
+ import hiveContext.implicits._
before {
Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes")
@@ -122,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter {
test(name) {
val error = intercept[AnalysisException] {
- quietly(sql(query))
+ quietly(hiveContext.sql(query))
}
assert(!error.getMessage.contains("Seq("))
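
The assertion above leans on ScalaTest's intercept, which runs a block, fails the test unless it throws the expected exception type, and returns the caught exception so its message can be inspected. A small sketch of the same shape, assuming only ScalaTest on the classpath and a made-up parse helper:

    import org.scalatest.FunSuite

    class InterceptSketch extends FunSuite {
      // Hypothetical stand-in for a query that fails analysis.
      private def parse(query: String): Unit = {
        if (!query.toLowerCase.startsWith("select")) {
          throw new IllegalArgumentException(s"cannot resolve query: $query")
        }
      }

      test("bad queries surface a readable message") {
        val error = intercept[IllegalArgumentException] {
          parse("MUNGE TABLE dupAttributes")
        }
        assert(error.getMessage.contains("cannot resolve"))
      }
    }
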
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index fb10f8583d..2e5cae415e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -19,24 +19,25 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.scalatest.BeforeAndAfterAll
// TODO ideally we should put the test suite into the package `sql`, as
// the `hive` package is optional at compile time; however, `SQLContext.sql` doesn't
// support the `cube` or `rollup` yet.
-class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll {
+class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
+ import hiveContext.implicits._
+ import hiveContext.sql
+
private var testData: DataFrame = _
override def beforeAll() {
testData = Seq((1, 2), (2, 4)).toDF("a", "b")
- TestHive.registerDataFrameAsTable(testData, "mytable")
+ hiveContext.registerDataFrameAsTable(testData, "mytable")
}
override def afterAll(): Unit = {
- TestHive.dropTempTable("mytable")
+ hiveContext.dropTempTable("mytable")
}
test("rollup") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
index 52e782768c..f621367eb5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.hive
import org.apache.spark.sql.{Row, QueryTest}
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
-
-class HiveDataFrameJoinSuite extends QueryTest {
+class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton {
+ import hiveContext.implicits._
// We should move this into SQL package if we make case sensitivity configurable in SQL.
test("join - self join auto resolve ambiguity with case insensitivity") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
index c177cbdd99..2c98f1c3cc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
-class HiveDataFrameWindowSuite extends QueryTest {
+class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton {
+ import hiveContext.implicits._
+ import hiveContext.sql
test("reuse window partitionBy") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 574624d501..107457f79e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -19,18 +19,15 @@ package org.apache.spark.sql.hive
import java.io.File
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{QueryTest, Row, SaveMode}
import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable}
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
-import org.apache.spark.sql.sources.DataSourceTest
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
-import org.apache.spark.sql.{Row, SaveMode, SQLContext}
-import org.apache.spark.{Logging, SparkFunSuite}
-
-class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
+class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton {
+ import hiveContext.implicits._
test("struct field should accept underscore in sub-column name") {
val hiveTypeStr = "struct<a: int, b_1: string, c: string>"
@@ -46,14 +43,15 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
}
test("duplicated metastore relations") {
- val df = sql("SELECT * FROM src")
+ val df = hiveContext.sql("SELECT * FROM src")
logInfo(df.queryExecution.toString)
df.as('a).join(df.as('b), $"a.key" === $"b.key")
}
}
-class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
- override def _sqlContext: SQLContext = TestHive
+class DataSourceWithHiveMetastoreCatalogSuite
+ extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import hiveContext._
import testImplicits._
private val testDF = range(1, 3).select(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index fe0db5228d..5596ec6882 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -17,15 +17,13 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
-import org.apache.spark.sql.{QueryTest, Row, SQLContext}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.{QueryTest, Row}
case class Cases(lower: String, UPPER: String)
-class HiveParquetSuite extends QueryTest with ParquetTest {
- private val ctx = TestHive
- override def _sqlContext: SQLContext = ctx
+class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton {
test("Case insensitive attribute names") {
withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
@@ -53,7 +51,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
test("Converting Hive to Parquet Table via saveAsParquetFile") {
withTempPath { dir =>
sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
- ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p")
+ hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
checkAnswer(
sql("SELECT * FROM src ORDER BY key"),
@@ -66,7 +64,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
withTempPath { file =>
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
- ctx.read.parquet(file.getCanonicalPath).registerTempTable("p")
+ hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index dc2d85f486..84f3db44ec 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
import org.apache.spark._
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{SQLContext, QueryTest}
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
import org.apache.spark.sql.types.DecimalType
@@ -272,7 +272,11 @@ object SparkSQLConfTest extends Logging {
}
}
-object SPARK_9757 extends QueryTest with Logging {
+object SPARK_9757 extends QueryTest {
+ import org.apache.spark.sql.functions._
+
+ protected var sqlContext: SQLContext = _
+
def main(args: Array[String]): Unit = {
Utils.configTestLog4j("INFO")
@@ -282,10 +286,9 @@ object SPARK_9757 extends QueryTest with Logging {
.set("spark.sql.hive.metastore.jars", "maven"))
val hiveContext = new TestHiveContext(sparkContext)
+ sqlContext = hiveContext
import hiveContext.implicits._
- import org.apache.spark.sql.functions._
-
val dir = Utils.createTempDir()
dir.delete()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index d33e81227d..80a61f82fd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -24,28 +24,25 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.{QueryTest, _}
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-/* Implicits */
-import org.apache.spark.sql.hive.test.TestHive._
-
case class TestData(key: Int, value: String)
case class ThreeCloumntable(key: Int, value: String, key1: String)
-class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
-
+class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter {
+ import hiveContext.implicits._
+ import hiveContext.sql
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
before {
// Since we are doing tests for DDL statements,
// it is better to reset before every test.
- TestHive.reset()
+ hiveContext.reset()
// Register the testData, which will be used in every test.
testData.registerTempTable("testData")
}
@@ -96,9 +93,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
test("SPARK-4052: scala.collection.Map as value type of MapType") {
val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil)
- val rowRDD = TestHive.sparkContext.parallelize(
+ val rowRDD = hiveContext.sparkContext.parallelize(
(1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i"))))
- val df = TestHive.createDataFrame(rowRDD, schema)
+ val df = hiveContext.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m MAP <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -169,8 +166,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
test("Insert ArrayType.containsNull == false") {
val schema = StructType(Seq(
StructField("a", ArrayType(StringType, containsNull = false))))
- val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
- val df = TestHive.createDataFrame(rowRDD, schema)
+ val rowRDD = hiveContext.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
+ val df = hiveContext.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithArrayValue")
sql("CREATE TABLE hiveTableWithArrayValue(a Array <STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
@@ -185,9 +182,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
test("Insert MapType.valueContainsNull == false") {
val schema = StructType(Seq(
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
- val rowRDD = TestHive.sparkContext.parallelize(
+ val rowRDD = hiveContext.sparkContext.parallelize(
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
- val df = TestHive.createDataFrame(rowRDD, schema)
+ val df = hiveContext.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m Map <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -202,9 +199,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
test("Insert StructType.fields.exists(_.nullable == false)") {
val schema = StructType(Seq(
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
- val rowRDD = TestHive.sparkContext.parallelize(
+ val rowRDD = hiveContext.sparkContext.parallelize(
(1 to 100).map(i => Row(Row(s"value$i"))))
- val df = TestHive.createDataFrame(rowRDD, schema)
+ val df = hiveContext.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithStructValue")
sql("CREATE TABLE hiveTableWithStructValue(s Struct <f: STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
@@ -217,11 +214,11 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
}
test("SPARK-5498:partition schema does not match table schema") {
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
(1 to 10).map(i => TestData(i, i.toString))).toDF()
testData.registerTempTable("testData")
- val testDatawithNull = TestHive.sparkContext.parallelize(
+ val testDatawithNull = hiveContext.sparkContext.parallelize(
(1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF()
val tmpDir = Utils.createTempDir()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
index d3388a9429..579631df77 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.sql.hive
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.Row
-class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
+class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
+ import hiveContext._
+ import hiveContext.implicits._
- import org.apache.spark.sql.hive.test.TestHive.implicits._
-
- val df =
- sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")
+ val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")
override def beforeAll(): Unit = {
// The catalog in HiveContext is a case insensitive one.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 20a50586d5..bf0db08490 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -22,15 +22,11 @@ import java.io.{IOException, File}
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
-import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable}
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
@@ -39,10 +35,9 @@ import org.apache.spark.util.Utils
/**
* Tests for persisting tables created though the data sources API into the metastore.
*/
-class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
- with Logging {
- override def _sqlContext: SQLContext = TestHive
- private val sqlContext = _sqlContext
+class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import hiveContext._
+ import hiveContext.implicits._
var jsonFilePath: String = _
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 997c667ec0..f16c257ab5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -17,20 +17,16 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode}
+import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode}
-class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
- override val _sqlContext: HiveContext = TestHive
- private val sqlContext = _sqlContext
-
- private val df = sqlContext.range(10).coalesce(1)
+class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ private lazy val df = sqlContext.range(10).coalesce(1)
private def checkTablePath(dbName: String, tableName: String): Unit = {
- // val hiveContext = sqlContext.asInstanceOf[HiveContext]
- val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName)
- val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName
+ val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName)
+ val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName
assert(metastoreTable.serdeProperties("path") === expectedPath)
}
@@ -220,7 +216,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
df.write.parquet(s"$path/p=2")
sql("ALTER TABLE t ADD PARTITION (p=2)")
- sqlContext.refreshTable("t")
+ hiveContext.refreshTable("t")
checkAnswer(
sqlContext.table("t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
@@ -252,7 +248,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
df.write.parquet(s"$path/p=2")
sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)")
- sqlContext.refreshTable(s"$db.t")
+ hiveContext.refreshTable(s"$db.t")
checkAnswer(
sqlContext.table(s"$db.t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
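
df becomes a lazy val here, so sqlContext.range(10).coalesce(1) is only constructed when a test first touches df rather than while the suite itself is being built. A Scala-only sketch of that timing difference, with made-up names:

    object LazyValSketch extends App {
      def buildPlan(tag: String): String = {
        println(s"building plan for $tag")   // side effect, so we can see when it runs
        s"plan($tag)"
      }

      class StrictSuite {
        val df: String = buildPlan("strict")      // runs as soon as the suite is constructed
      }
      class LazySuite {
        lazy val df: String = buildPlan("lazy")   // runs only on first access
      }

      println("constructing suites")
      val eager = new StrictSuite     // prints: building plan for strict
      val deferred = new LazySuite    // prints nothing yet
      println("first access")
      println(deferred.df)            // prints: building plan for lazy, then plan(lazy)
    }
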
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
index 91d7a48208..49aab85cf1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
@@ -18,38 +18,20 @@
package org.apache.spark.sql.hive
import java.sql.Timestamp
-import java.util.{Locale, TimeZone}
import org.apache.hadoop.hive.conf.HiveConf
-import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.{Row, SQLConf, SQLContext}
-
-class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with BeforeAndAfterAll {
- override def _sqlContext: SQLContext = TestHive
- private val sqlContext = _sqlContext
+import org.apache.spark.sql.{Row, SQLConf}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton {
/**
* Set the staging directory (and hence path to ignore Parquet files under)
* to that set by [[HiveConf.ConfVars.STAGINGDIR]].
*/
private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)
- private val originalTimeZone = TimeZone.getDefault
- private val originalLocale = Locale.getDefault
-
- protected override def beforeAll(): Unit = {
- TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
- Locale.setDefault(Locale.US)
- }
-
- override protected def afterAll(): Unit = {
- TimeZone.setDefault(originalTimeZone)
- Locale.setDefault(originalLocale)
- }
-
override protected def logParquetSchema(path: String): Unit = {
val schema = readParquetSchema(path, { path =>
!path.getName.startsWith("_") && !path.getName.startsWith(stagingDir)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
index 1cc8a93e83..f542a5a025 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
@@ -18,22 +18,18 @@
package org.apache.spark.sql.hive
import com.google.common.io.Files
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.util.Utils
+import org.apache.spark.sql.{QueryTest, _}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import hiveContext.implicits._
-class QueryPartitionSuite extends QueryTest with SQLTestUtils {
-
- private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
- import ctx.implicits._
-
- protected def _sqlContext = ctx
-
- test("SPARK-5068: query data when path doesn't exist"){
+ test("SPARK-5068: query data when path doesn't exist") {
withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) {
- val testData = ctx.sparkContext.parallelize(
+ val testData = sparkContext.parallelize(
(1 to 10).map(i => TestData(i, i.toString))).toDF()
testData.registerTempTable("testData")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index e4fec7e2c8..6a692d6fce 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -17,24 +17,15 @@
package org.apache.spark.sql.hive
-import org.scalatest.BeforeAndAfterAll
-
import scala.reflect.ClassTag
import org.apache.spark.sql.{Row, SQLConf, QueryTest}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.execution._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
-class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
-
- private lazy val ctx: HiveContext = {
- val ctx = org.apache.spark.sql.hive.test.TestHive
- ctx.reset()
- ctx.cacheTables = false
- ctx
- }
-
- import ctx.sql
+class StatisticsSuite extends QueryTest with TestHiveSingleton {
+ import hiveContext.sql
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
@@ -54,9 +45,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
}
}
- // Ensure session state is initialized.
- ctx.parseSql("use default")
-
assertAnalyzeCommand(
"ANALYZE TABLE Table1 COMPUTE STATISTICS",
classOf[HiveNativeCommand])
@@ -80,7 +68,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
test("analyze MetastoreRelations") {
def queryTotalSize(tableName: String): BigInt =
- ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
+ hiveContext.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
// Non-partitioned table
sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
@@ -114,7 +102,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
|SELECT * FROM src
""".stripMargin).collect()
- assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes)
+ assert(queryTotalSize("analyzeTable_part") === hiveContext.conf.defaultSizeInBytes)
sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan")
@@ -125,9 +113,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
// Try to analyze a temp table
sql("""SELECT * FROM src""").registerTempTable("tempTable")
intercept[UnsupportedOperationException] {
- ctx.analyze("tempTable")
+ hiveContext.analyze("tempTable")
}
- ctx.catalog.unregisterTable(Seq("tempTable"))
+ hiveContext.catalog.unregisterTable(Seq("tempTable"))
}
test("estimates the size of a test MetastoreRelation") {
@@ -155,8 +143,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
val sizes = df.queryExecution.analyzed.collect {
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
}
- assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold
- && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold,
+ assert(sizes.size === 2 && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold
+ && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold,
s"query should contain two relations, each of which has size smaller than autoConvertSize")
// Using `sparkPlan` because for relevant patterns in HashJoin to be
@@ -167,8 +155,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(df, expectedAnswer) // check correctness of output
- ctx.conf.settings.synchronized {
- val tmp = ctx.conf.autoBroadcastJoinThreshold
+ hiveContext.conf.settings.synchronized {
+ val tmp = hiveContext.conf.autoBroadcastJoinThreshold
sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""")
df = sql(query)
@@ -211,8 +199,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
.isAssignableFrom(r.getClass) =>
r.statistics.sizeInBytes
}
- assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold
- && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold,
+ assert(sizes.size === 2 && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold
+ && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold,
s"query should contain two relations, each of which has size smaller than autoConvertSize")
// Using `sparkPlan` because for relevant patterns in HashJoin to be
@@ -225,8 +213,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(df, answer) // check correctness of output
- ctx.conf.settings.synchronized {
- val tmp = ctx.conf.autoBroadcastJoinThreshold
+ hiveContext.conf.settings.synchronized {
+ val tmp = hiveContext.conf.autoBroadcastJoinThreshold
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
df = sql(leftSemiJoinQuery)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 7ee1c8d13a..3ab4576811 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -18,18 +18,18 @@
package org.apache.spark.sql.hive
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.hive.test.TestHiveSingleton
case class FunctionResult(f1: String, f2: String)
-class UDFSuite extends QueryTest {
- private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+class UDFSuite extends QueryTest with TestHiveSingleton {
test("UDF case insensitive") {
- ctx.udf.register("random0", () => { Math.random() })
- ctx.udf.register("RANDOM1", () => { Math.random() })
- ctx.udf.register("strlenScala", (_: String).length + (_: Int))
- assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
- assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
- assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
+ hiveContext.udf.register("random0", () => { Math.random() })
+ hiveContext.udf.register("RANDOM1", () => { Math.random() })
+ hiveContext.udf.register("strlenScala", (_: String).length + (_: Int))
+ assert(hiveContext.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+ assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+ assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
}
}
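
The pattern under test above is registering a plain Scala function as a SQL UDF on the context and then calling it by name from a query. A minimal sketch against a plain SQLContext, with a hypothetical strlenplus function (the Hive suite additionally checks that lookup is case-insensitive):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object UdfSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("udf-sketch"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        Seq(("test", 1)).toDF("s", "n").registerTempTable("t")

        // Register a two-argument Scala function under a SQL-callable name.
        sqlContext.udf.register("strlenplus", (s: String, n: Int) => s.length + n)

        val result = sqlContext.sql("SELECT strlenplus(s, n) FROM t").head().getInt(0)
        assert(result == 5)   // "test".length + 1
        sc.stop()
      }
    }
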
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4886a85948..b126ec455f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,19 +17,15 @@
package org.apache.spark.sql.hive.execution
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.sql._
import org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
-abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
- override def _sqlContext: SQLContext = TestHive
- protected val sqlContext = _sqlContext
- import sqlContext.implicits._
+abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import testImplicits._
var originalUseAggregate2: Boolean = _
@@ -69,7 +65,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
data2.write.saveAsTable("agg2")
val emptyDF = sqlContext.createDataFrame(
- sqlContext.sparkContext.emptyRDD[Row],
+ sparkContext.emptyRDD[Row],
StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
emptyDF.registerTempTable("emptyTable")
@@ -597,7 +593,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue
sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt")
}
- override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
(0 to 2).foreach { fallbackStartsAt =>
sqlContext.setConf(
"spark.sql.TungstenAggregate.testFallbackStartsAt",
@@ -605,6 +601,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
+ // todo: remove it?
val newActual = DataFrame(sqlContext, actual.logicalPlan)
QueryTest.checkAnswer(newActual, expectedAnswer) match {
@@ -626,12 +623,12 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue
}
// Override it to make sure we call the actually overridden checkAnswer.
- override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+ override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(df, Seq(expectedAnswer))
}
// Override it to make sure we call the actually overridden checkAnswer.
- override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+ override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
checkAnswer(df, expectedAnswer.collect())
}
}
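
The interesting part of these overrides is the parameter type change from DataFrame to => DataFrame: with a by-name parameter the query expression supplied by callers can be re-evaluated for every fallback setting the loop applies, instead of being evaluated once at the call site before any conf change takes effect. A tiny Spark-free sketch of that difference:

    object ByNameSketch {
      var threshold = 0   // stands in for the fallback conf the loop mutates

      // By-value: `plan` is evaluated once, before the loop ever changes `threshold`.
      def checkByValue(plan: String): Unit =
        (0 to 2).foreach { t => threshold = t; println(s"threshold=$t -> $plan") }

      // By-name: `plan` is re-evaluated on every access, so it sees each new threshold.
      def checkByName(plan: => String): Unit =
        (0 to 2).foreach { t => threshold = t; println(s"threshold=$t -> $plan") }

      def main(args: Array[String]): Unit = {
        checkByValue(s"plan built with threshold=$threshold")   // always shows threshold=0
        checkByName(s"plan built with threshold=$threshold")    // shows 0, 1, 2 in turn
      }
    }
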
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 4d45249d9c..aa95ba94fa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
import org.scalatest.{BeforeAndAfterAll, GivenWhenThen}
-import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
@@ -42,7 +42,7 @@ import org.apache.spark.sql.hive.test.TestHive
* configured using system properties.
*/
abstract class HiveComparisonTest
- extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
+ extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen {
/**
* When set, any cache files that result in test failures will be deleted. Used when the test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 11d7a872df..94162da4ea 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -17,17 +17,14 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.{SQLContext, QueryTest}
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.hive.test.TestHiveSingleton
/**
* A set of tests that validates support for Hive Explain command.
*/
-class HiveExplainSuite extends QueryTest with SQLTestUtils {
- override def _sqlContext: SQLContext = TestHive
- private val sqlContext = _sqlContext
+class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
test("explain extended command") {
checkExistence(sql(" explain select * from src where key=123 "), true,
@@ -83,7 +80,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils {
test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") {
withTempTable("jt") {
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
- read.json(rdd).registerTempTable("jt")
+ hiveContext.read.json(rdd).registerTempTable("jt")
val outputs = sql(
s"""
|EXPLAIN EXTENDED
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala
index efbef68cd4..0d4c7f86b3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala
@@ -18,14 +18,16 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.{Row, QueryTest}
-import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}
/**
* A set of tests that validates that commands can also be queried like a table
*/
-class HiveOperatorQueryableSuite extends QueryTest {
+class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton {
+ import hiveContext._
+
test("SPARK-5324 query result of describe command") {
- loadTestTable("src")
+ hiveContext.loadTestTable("src")
// register a describe command to be a temp table
sql("desc src").registerTempTable("mydesc")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
index ba56a8a6b6..cd055f9eca 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
@@ -21,11 +21,11 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
-class HivePlanTest extends QueryTest {
- import TestHive._
- import TestHive.implicits._
+class HivePlanTest extends QueryTest with TestHiveSingleton {
+ import hiveContext.sql
+ import hiveContext.implicits._
test("udf constant folding") {
Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 9c10ffe111..d9ba895e1e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.io.Writable
import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf}
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.util.Utils
@@ -43,10 +43,10 @@ case class ListStringCaseClass(l: Seq[String])
/**
* A test suite for Hive custom UDFs.
*/
-class HiveUDFSuite extends QueryTest {
+class HiveUDFSuite extends QueryTest with TestHiveSingleton {
- import TestHive.{udf, sql}
- import TestHive.implicits._
+ import hiveContext.{udf, sql}
+ import hiveContext.implicits._
test("spark sql udf test that returns a struct") {
udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
@@ -123,12 +123,12 @@ class HiveUDFSuite extends QueryTest {
| "value", value)).value FROM src
""".stripMargin), Seq(Row("val_0")))
}
- val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED)
- TestHive.setConf(SQLConf.CODEGEN_ENABLED, true)
+ val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED)
+ hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true)
testOrderInStruct()
- TestHive.setConf(SQLConf.CODEGEN_ENABLED, false)
+ hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false)
testOrderInStruct()
- TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
+ hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
}
test("SPARK-6409 UDAFAverage test") {
@@ -137,7 +137,7 @@ class HiveUDFSuite extends QueryTest {
sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),
Seq(Row(1.0, 260.182)))
sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg")
- TestHive.reset()
+ hiveContext.reset()
}
test("SPARK-2693 udaf aggregates test") {
@@ -157,7 +157,7 @@ class HiveUDFSuite extends QueryTest {
}
test("UDFIntegerToString") {
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF()
testData.registerTempTable("integerTable")
@@ -168,11 +168,11 @@ class HiveUDFSuite extends QueryTest {
Seq(Row("1"), Row("2")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFToListString") {
- val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
+ val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
testData.registerTempTable("inputTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'")
@@ -183,11 +183,11 @@ class HiveUDFSuite extends QueryTest {
"JVM type erasure makes spark fail to catch a component type in List<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFToListInt") {
- val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
+ val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
testData.registerTempTable("inputTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'")
@@ -198,11 +198,11 @@ class HiveUDFSuite extends QueryTest {
"JVM type erasure makes spark fail to catch a component type in List<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFToStringIntMap") {
- val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
+ val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
testData.registerTempTable("inputTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " +
@@ -214,11 +214,11 @@ class HiveUDFSuite extends QueryTest {
"JVM type erasure makes spark fail to catch key and value types in Map<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFToIntIntMap") {
- val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
+ val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
testData.registerTempTable("inputTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " +
@@ -230,11 +230,11 @@ class HiveUDFSuite extends QueryTest {
"JVM type erasure makes spark fail to catch key and value types in Map<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFListListInt") {
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF()
@@ -246,11 +246,11 @@ class HiveUDFSuite extends QueryTest {
Seq(Row(0), Row(2), Row(13)))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFListString") {
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
ListStringCaseClass(Seq("a", "b", "c")) ::
ListStringCaseClass(Seq("d", "e")) :: Nil).toDF()
testData.registerTempTable("listStringTable")
@@ -261,11 +261,11 @@ class HiveUDFSuite extends QueryTest {
Seq(Row("a,b,c"), Row("d,e")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFStringString") {
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF()
testData.registerTempTable("stringTable")
@@ -280,11 +280,11 @@ class HiveUDFSuite extends QueryTest {
sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF")
- TestHive.reset()
+ hiveContext.reset()
}
test("UDFTwoListList") {
- val testData = TestHive.sparkContext.parallelize(
+ val testData = hiveContext.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
@@ -297,7 +297,7 @@ class HiveUDFSuite extends QueryTest {
Seq(Row("0, 0"), Row("2, 2"), Row("13, 13")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
- TestHive.reset()
+ hiveContext.reset()
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 1ff1d9a293..8126d02335 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -26,9 +26,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries}
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation}
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.test.SQLTestUtils
@@ -65,12 +63,12 @@ class MyDialect extends DefaultParserDialect
* Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
* valid, but Hive currently cannot execute it.
*/
-class SQLQuerySuite extends QueryTest with SQLTestUtils {
- override def _sqlContext: SQLContext = TestHive
- private val sqlContext = _sqlContext
+class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import hiveContext._
+ import hiveContext.implicits._
test("UDTF") {
- sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
+ sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
// The function source code can be found at:
// https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
sql(
@@ -509,19 +507,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
checkAnswer(
sql("SELECT f1.f2.f3 FROM nested"),
Row(1))
- checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"),
- Seq.empty[Row])
+
+ sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested")
checkAnswer(
sql("SELECT * FROM test_ctas_1234"),
sql("SELECT * FROM nested").collect().toSeq)
intercept[AnalysisException] {
- sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect()
+ sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect()
}
}
test("test CTAS") {
- checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row])
+ sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src")
checkAnswer(
sql("SELECT key, value FROM test_ctas_123 ORDER BY key"),
sql("SELECT key, value FROM src ORDER BY key").collect().toSeq)
@@ -614,7 +612,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
val rowRdd = sparkContext.parallelize(row :: Nil)
- TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable")
+ hiveContext.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes
@@ -1044,10 +1042,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
val thread = new Thread {
override def run() {
// To make sure this test works, this jar should not be loaded in another place.
- TestHive.sql(
- s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}")
+ sql(
+ s"ADD JAR ${hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}")
try {
- TestHive.sql(
+ sql(
"""
|CREATE TEMPORARY FUNCTION example_max
|AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax'
@@ -1097,21 +1095,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") {
val df =
- TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01")))
+ createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01")))
df.toDF("id", "datef").registerTempTable("test_SPARK8588")
checkAnswer(
- TestHive.sql(
+ sql(
"""
|select id, concat(year(datef))
|from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year')
""".stripMargin),
Row(1, "2014") :: Row(2, "2015") :: Nil
)
- TestHive.dropTempTable("test_SPARK8588")
+ dropTempTable("test_SPARK8588")
}
test("SPARK-9371: fix the support for special chars in column names for hive context") {
- TestHive.read.json(TestHive.sparkContext.makeRDD(
+ read.json(sparkContext.makeRDD(
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
.registerTempTable("t")
@@ -1142,8 +1140,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
test("specifying database name for a temporary table is not allowed") {
withTempPath { dir =>
val path = dir.getCanonicalPath
- val df =
- sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
+ val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
df
.write
.format("parquet")
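
The SQLQuerySuite hunks above replace the per-suite `override def _sqlContext = TestHive` boilerplate with a single TestHiveSingleton mix-in, so every Hive test reaches the shared context through `hiveContext`. A minimal sketch of the shape such a trait needs for the `import hiveContext._` / `import hiveContext.implicits._` lines to compile is below; the trait name and body are illustrative, not the actual definition.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.test.TestHive

// Illustrative only: exposes the shared TestHive object under the names the
// refactored suites import from. The real TestHiveSingleton may differ.
trait TestHiveSingletonSketch {
  protected val hiveContext: HiveContext = TestHive   // one context shared across all suites
  protected val sqlContext: SQLContext = hiveContext  // generic alias used by SQLTestUtils helpers
}
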
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 9aca40f15a..cb8d0fca8e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -22,17 +22,14 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest}
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.types.StringType
-class ScriptTransformationSuite extends SparkPlanTest {
-
- override def _sqlContext: SQLContext = TestHive
- private val sqlContext = _sqlContext
+class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
+ import hiveContext.implicits._
private val noSerdeIOSchema = HiveScriptIOSchema(
inputRowFormat = Seq.empty,
@@ -59,7 +56,7 @@ class ScriptTransformationSuite extends SparkPlanTest {
output = Seq(AttributeReference("a", StringType)()),
child = child,
ioschema = noSerdeIOSchema
- )(TestHive),
+ )(hiveContext),
rowsDf.collect())
}
@@ -73,7 +70,7 @@ class ScriptTransformationSuite extends SparkPlanTest {
output = Seq(AttributeReference("a", StringType)()),
child = child,
ioschema = serdeIOSchema
- )(TestHive),
+ )(hiveContext),
rowsDf.collect())
}
@@ -88,7 +85,7 @@ class ScriptTransformationSuite extends SparkPlanTest {
output = Seq(AttributeReference("a", StringType)()),
child = ExceptionInjectingOperator(child),
ioschema = noSerdeIOSchema
- )(TestHive),
+ )(hiveContext),
rowsDf.collect())
}
assert(e.getMessage().contains("intentional exception"))
@@ -105,7 +102,7 @@ class ScriptTransformationSuite extends SparkPlanTest {
output = Seq(AttributeReference("a", StringType)()),
child = ExceptionInjectingOperator(child),
ioschema = serdeIOSchema
- )(TestHive),
+ )(hiveContext),
rowsDf.collect())
}
assert(e.getMessage().contains("intentional exception"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
index deec0048d2..9a299c3f9d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -24,10 +24,9 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest
import org.apache.spark.sql.types._
class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
- override val dataSourceName: String = classOf[DefaultSource].getCanonicalName
+ import testImplicits._
- import sqlContext._
- import sqlContext.implicits._
+ override val dataSourceName: String = classOf[DefaultSource].getCanonicalName
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
@@ -48,7 +47,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
checkQueries(
- read.options(Map(
+ hiveContext.read.options(Map(
"path" -> file.getCanonicalPath,
"dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load())
}
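
With the wildcard `import sqlContext._` gone, reads go through the context explicitly. A hedged usage sketch of the same call chain as the hunk above, with a made-up path and schema:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Placeholder schema and path; the suite above derives both from its test fixtures.
val dataSchema = StructType(Seq(
  StructField("a", IntegerType, nullable = true),
  StructField("b", StringType, nullable = true)))

val df: DataFrame = TestHive.read
  .options(Map(
    "path" -> "/tmp/orc-relation-test",   // hypothetical location
    "dataSchema" -> dataSchema.json))     // schema handed to the source as JSON, as above
  .format("orc")
  .load()
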
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
index a46ca9a2c9..52e09f9496 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
@@ -18,19 +18,17 @@
package org.apache.spark.sql.hive.orc
import java.io.File
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
-import org.apache.spark.util.Utils
-import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
+import org.scalatest.BeforeAndAfterAll
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.util.Utils
+
// The data where the partitioning key exists only in the directory structure.
case class OrcParData(intField: Int, stringField: String)
@@ -38,7 +36,10 @@ case class OrcParData(intField: Int, stringField: String)
case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot
-class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll {
+class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
+ import hiveContext._
+ import hiveContext.implicits._
+
val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal
def withTempDir(f: File => Unit): Unit = {
@@ -58,7 +59,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll {
}
protected def withTempTable(tableName: String)(f: => Unit): Unit = {
- try f finally TestHive.dropTempTable(tableName)
+ try f finally hiveContext.dropTempTable(tableName)
}
protected def makePartitionDir(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 80c38084f2..7a34cf731b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -21,12 +21,14 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
case class OrcData(intField: Int, stringField: String)
-abstract class OrcSuite extends QueryTest with BeforeAndAfterAll {
+abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
+ import hiveContext._
+
var orcTableDir: File = null
var orcTableAsDir: File = null
@@ -156,7 +158,7 @@ class OrcSourceSuite extends OrcSuite {
override def beforeAll(): Unit = {
super.beforeAll()
- sql(
+ hiveContext.sql(
s"""CREATE TEMPORARY TABLE normal_orc_source
|USING org.apache.spark.sql.hive.orc
|OPTIONS (
@@ -164,7 +166,7 @@ class OrcSourceSuite extends OrcSuite {
|)
""".stripMargin)
- sql(
+ hiveContext.sql(
s"""CREATE TEMPORARY TABLE normal_orc_as_source
|USING org.apache.spark.sql.hive.orc
|OPTIONS (
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index f7ba20ff41..88a0ed5117 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -22,15 +22,12 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.hive.test.TestHiveSingleton
-private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
- protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
- protected val sqlContext = _sqlContext
- import sqlContext.implicits._
- import sqlContext.sparkContext
+private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
+ import testImplicits._
/**
* Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f`
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 34d3434569..6842ec2b5e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -19,15 +19,11 @@ package org.apache.spark.sql.hive
import java.io.File
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD}
import org.apache.spark.sql.hive.execution.HiveTableScan
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
@@ -58,6 +54,8 @@ case class ParquetDataWithKeyAndComplexTypes(
* built in parquet support.
*/
class ParquetMetastoreSuite extends ParquetPartitioningTest {
+ import hiveContext._
+
override def beforeAll(): Unit = {
super.beforeAll()
dropTables("partitioned_parquet",
@@ -536,6 +534,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
* A suite of tests for the Parquet support through the data sources API.
*/
class ParquetSourceSuite extends ParquetPartitioningTest {
+ import testImplicits._
+ import hiveContext._
+
override def beforeAll(): Unit = {
super.beforeAll()
dropTables("partitioned_parquet",
@@ -684,9 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
/**
* A collection of tests for parquet data with various forms of partitioning.
*/
-abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
- override def _sqlContext: SQLContext = TestHive
- protected val sqlContext = _sqlContext
+abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import testImplicits._
var partitionedTableDir: File = null
var normalTableDir: File = null
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index b4640b1616..dc0531a6d4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -18,16 +18,13 @@
package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
-class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
- override def _sqlContext: SQLContext = TestHive
- private val sqlContext = _sqlContext
+class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton {
// When committing a task, `CommitFailureTestSource` throws an exception for testing purpose.
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala
index 8ca3a17085..1945b15002 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala
@@ -28,8 +28,6 @@ import org.apache.spark.sql.types._
class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
override val dataSourceName: String = "json"
- import sqlContext._
-
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
val basePath = new Path(file.getCanonicalPath)
@@ -47,7 +45,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
checkQueries(
- read.format(dataSourceName)
+ hiveContext.read.format(dataSourceName)
.option("dataSchema", dataSchemaWithPartition.json)
.load(file.getCanonicalPath))
}
@@ -65,14 +63,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
val data =
Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) ::
Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil
- val df = createDataFrame(sparkContext.parallelize(data), schema)
+ val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema)
// Write the data out.
df.write.format(dataSourceName).save(file.getCanonicalPath)
// Read it back and check the result.
checkAnswer(
- read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
+ hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
df
)
}
@@ -90,14 +88,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
Row(new BigDecimal("10.02")) ::
Row(new BigDecimal("20000.99")) ::
Row(new BigDecimal("10000")) :: Nil
- val df = createDataFrame(sparkContext.parallelize(data), schema)
+ val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema)
// Write the data out.
df.write.format(dataSourceName).save(file.getCanonicalPath)
// Read it back and check the result.
checkAnswer(
- read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
+ hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath),
df
)
}
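
The JSON hunks above follow one pattern: build a DataFrame from Rows with an explicit schema via the context, write it through the data source, and read it back with the same schema before comparing. A self-contained sketch of that round trip, using placeholder values and a made-up path:

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)))

val data = Row(1L, "a") :: Row(2L, "b") :: Nil
val df = TestHive.createDataFrame(TestHive.sparkContext.parallelize(data), schema)

// Write through the source, then read back with the explicit schema (path is illustrative).
df.write.format("json").save("/tmp/json-roundtrip")
val roundTripped = TestHive.read.format("json").schema(schema).load("/tmp/json-roundtrip")
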
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index 06dadbb5fe..08c3c17973 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -28,10 +28,9 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
- override val dataSourceName: String = "parquet"
+ import testImplicits._
- import sqlContext._
- import sqlContext.implicits._
+ override val dataSourceName: String = "parquet"
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
@@ -51,7 +50,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
checkQueries(
- read.format(dataSourceName)
+ hiveContext.read.format(dataSourceName)
.option("dataSchema", dataSchemaWithPartition.json)
.load(file.getCanonicalPath))
}
@@ -69,7 +68,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
.format("parquet")
.save(s"${dir.getCanonicalPath}/_temporary")
- checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect())
+ checkAnswer(hiveContext.read.format("parquet").load(dir.getCanonicalPath), df.collect())
}
}
@@ -97,7 +96,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
// This shouldn't throw anything.
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
- checkAnswer(read.format("parquet").load(path), df)
+ checkAnswer(hiveContext.read.format("parquet").load(path), df)
}
}
@@ -107,7 +106,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
// Parquet doesn't allow field names with spaces. Here we are intentionally making an
// exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger
// the bug. Please refer to spark-8079 for more details.
- range(1, 10)
+ hiveContext.range(1, 10)
.withColumnRenamed("id", "a b")
.write
.format("parquet")
@@ -125,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
val summaryPath = new Path(path, "_metadata")
val commonSummaryPath = new Path(path, "_common_metadata")
- val fs = summaryPath.getFileSystem(configuration)
+ val fs = summaryPath.getFileSystem(hadoopConfiguration)
fs.delete(summaryPath, true)
fs.delete(commonSummaryPath, true)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index e8975e5f5c..1125ca6701 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -25,8 +25,6 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
- import sqlContext._
-
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
val basePath = new Path(file.getCanonicalPath)
@@ -44,7 +42,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
checkQueries(
- read.format(dataSourceName)
+ hiveContext.read.format(dataSourceName)
.option("dataSchema", dataSchemaWithPartition.json)
.load(file.getCanonicalPath))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 7966b43596..2ad2618dfc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -28,14 +28,12 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
-abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
- override def _sqlContext: SQLContext = TestHive
- protected val sqlContext = _sqlContext
+abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton {
import sqlContext.implicits._
val dataSourceName: String
@@ -504,17 +502,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
test("SPARK-8578 specified custom output committer will not be used to append data") {
- val clonedConf = new Configuration(configuration)
+ val clonedConf = new Configuration(hadoopConfiguration)
try {
val df = sqlContext.range(1, 10).toDF("i")
withTempPath { dir =>
df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath)
- configuration.set(
+ hadoopConfiguration.set(
SQLConf.OUTPUT_COMMITTER_CLASS.key,
classOf[AlwaysFailOutputCommitter].getName)
// Since Parquet has its own output committer setting, also set it
// to AlwaysFailParquetOutputCommitter at here.
- configuration.set("spark.sql.parquet.output.committer.class",
+ hadoopConfiguration.set("spark.sql.parquet.output.committer.class",
classOf[AlwaysFailParquetOutputCommitter].getName)
// Because there data already exists,
// this append should succeed because we will use the output committer associated
@@ -533,12 +531,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
}
withTempPath { dir =>
- configuration.set(
+ hadoopConfiguration.set(
SQLConf.OUTPUT_COMMITTER_CLASS.key,
classOf[AlwaysFailOutputCommitter].getName)
// Since Parquet has its own output committer setting, also set it
// to AlwaysFailParquetOutputCommitter at here.
- configuration.set("spark.sql.parquet.output.committer.class",
+ hadoopConfiguration.set("spark.sql.parquet.output.committer.class",
classOf[AlwaysFailParquetOutputCommitter].getName)
// Because there is no existing data,
// this append will fail because AlwaysFailOutputCommitter is used when we do append
@@ -549,8 +547,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
} finally {
// Hadoop 1 doesn't have `Configuration.unset`
- configuration.clear()
- clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
}
}
@@ -570,7 +568,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
test("SPARK-9899 Disable customized output committer when speculation is on") {
- val clonedConf = new Configuration(configuration)
+ val clonedConf = new Configuration(hadoopConfiguration)
val speculationEnabled =
sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false)
@@ -580,7 +578,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
sqlContext.sparkContext.conf.set("spark.speculation", "true")
// Uses a customized output committer which always fails
- configuration.set(
+ hadoopConfiguration.set(
SQLConf.OUTPUT_COMMITTER_CLASS.key,
classOf[AlwaysFailOutputCommitter].getName)
@@ -597,8 +595,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
} finally {
// Hadoop 1 doesn't have `Configuration.unset`
- configuration.clear()
- clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue))
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString)
}
}
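
The committer tests above mutate the shared Hadoop configuration and must put it back afterwards; because Hadoop 1 lacks `Configuration.unset`, they clear the configuration and replay a clone taken up front. A generic sketch of that save/mutate/restore pattern (the helper name is illustrative):

import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration

// Illustrative helper: snapshot a shared Configuration, run the body, then restore it.
def withRestoredConf(conf: Configuration)(body: => Unit): Unit = {
  val snapshot = new Configuration(conf)   // copy constructor preserves all current entries
  try {
    body
  } finally {
    // Hadoop 1 has no Configuration.unset, so clear everything and replay the snapshot.
    conf.clear()
    snapshot.asScala.foreach(entry => conf.set(entry.getKey, entry.getValue))
  }
}
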