author     Andrew Or <andrew@databricks.com>  2016-03-23 22:21:15 -0700
committer  Andrew Or <andrew@databricks.com>  2016-03-23 22:21:15 -0700
commit     c44d140cae99d0b880e6d25f158125ad3adc6a05 (patch)
tree       7f0e5324e67efeff2cccf661cd27a21c3618098c /sql/core/src
parent     cf823bead18c5be86b36da59b4bbf935c4804d04 (diff)
Revert "[SPARK-14014][SQL] Replace existing catalog with SessionCatalog"
This reverts commit 5dfc01976bb0d72489620b4f32cc12d620bb6260.
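
For context, the reverted commit had replaced the analysis-time Catalog trait with the new SessionCatalog; this revert restores the older API. A minimal sketch of the two temp-table registration styles, assuming code living in the org.apache.spark.sql package with a DataFrame `df` in scope (illustration only, not part of the commit):

    import org.apache.spark.sql.catalyst.TableIdentifier

    // Restored API: register by TableIdentifier on the analysis Catalog.
    sqlContext.sessionState.catalog.registerTable(
      TableIdentifier("t"), df.logicalPlan)

    // Reverted API (SessionCatalog, removed again by this commit):
    // sqlContext.sessionState.catalog.createTempTable(
    //   "t", df.logicalPlan, ignoreIfExists = true)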
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                      | 73
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala                      |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala                       | 24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala                     | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala                           |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala                                 | 15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala                                 |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                   | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala                               |  4
10 files changed, 73 insertions, 105 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index e413e77bc1..853a74c827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,14 +25,13 @@ import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
@@ -66,14 +65,13 @@ class SQLContext private[sql](
@transient val sparkContext: SparkContext,
@transient protected[sql] val cacheManager: CacheManager,
@transient private[sql] val listener: SQLListener,
- val isRootContext: Boolean,
- @transient private[sql] val externalCatalog: ExternalCatalog)
+ val isRootContext: Boolean)
extends Logging with Serializable {
self =>
- def this(sc: SparkContext) = {
- this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), true, new InMemoryCatalog)
+ def this(sparkContext: SparkContext) = {
+ this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true)
}
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
@@ -111,8 +109,7 @@ class SQLContext private[sql](
sparkContext = sparkContext,
cacheManager = cacheManager,
listener = listener,
- isRootContext = false,
- externalCatalog = externalCatalog)
+ isRootContext = false)
}
/**
@@ -189,12 +186,6 @@ class SQLContext private[sql](
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
- // Extract `spark.sql.*` entries and put it in our SQLConf.
- // Subclasses may additionally set these entries in other confs.
- SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) =>
- setConf(k, v)
- }
-
protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
@@ -208,6 +199,30 @@ class SQLContext private[sql](
sparkContext.addJar(path)
}
+ {
+ // We extract spark sql settings from SparkContext's conf and put them to
+ // Spark SQL's conf.
+ // First, we populate the SQLConf (conf). So, we can make sure that other values using
+ // those settings in their construction can get the correct settings.
+ // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version
+ // and spark.sql.hive.metastore.jars to get correctly constructed.
+ val properties = new Properties
+ sparkContext.getConf.getAll.foreach {
+ case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value)
+ case _ =>
+ }
+ // We directly put those settings to conf to avoid of calling setConf, which may have
+ // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive
+ // get constructed. If we call setConf directly, the constructed metadataHive may have
+ // wrong settings, or the construction may fail.
+ conf.setConf(properties)
+ // After we have populated SQLConf, we call setConf to populate other confs in the subclass
+ // (e.g. hiveconf in HiveContext).
+ properties.asScala.foreach {
+ case (key, value) => setConf(key, value)
+ }
+ }
+
/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -668,10 +683,8 @@ class SQLContext private[sql](
* only during the lifetime of this instance of SQLContext.
*/
private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
- sessionState.catalog.createTempTable(
- sessionState.sqlParser.parseTableIdentifier(tableName).table,
- df.logicalPlan,
- ignoreIfExists = true)
+ sessionState.catalog.registerTable(
+ sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
}
/**
@@ -684,7 +697,7 @@ class SQLContext private[sql](
*/
def dropTempTable(tableName: String): Unit = {
cacheManager.tryUncacheQuery(table(tableName))
- sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true)
+ sessionState.catalog.unregisterTable(TableIdentifier(tableName))
}
/**
@@ -811,7 +824,9 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tableNames(): Array[String] = {
- tableNames(sessionState.catalog.getCurrentDatabase)
+ sessionState.catalog.getTables(None).map {
+ case (tableName, _) => tableName
+ }.toArray
}
/**
@@ -821,7 +836,9 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tableNames(databaseName: String): Array[String] = {
- sessionState.catalog.listTables(databaseName).map(_.table).toArray
+ sessionState.catalog.getTables(Some(databaseName)).map {
+ case (tableName, _) => tableName
+ }.toArray
}
@transient
@@ -1008,18 +1025,4 @@ object SQLContext {
}
sqlListener.get()
}
-
- /**
- * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
- */
- private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = {
- val properties = new Properties
- sparkConf.getAll.foreach { case (key, value) =>
- if (key.startsWith("spark.sql")) {
- properties.setProperty(key, value)
- }
- }
- properties
- }
-
}
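
The initializer block restored above is order-sensitive: SQLConf is populated via conf.setConf(properties) before setConf(key, value) runs, so that subclass overrides of setConf (e.g. HiveContext constructing metadataHive) already see the correct settings. A small sketch of the observable behavior, assuming a local SparkContext with a hypothetical app name:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(
      new SparkConf()
        .setMaster("local[2]")
        .setAppName("conf-propagation-sketch")
        .set("spark.sql.shuffle.partitions", "4"))
    val sqlContext = new SQLContext(sc)
    // The spark.sql.* entry was copied into SQLConf during construction.
    assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "4")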
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 964f0a7a7b..59c3ffcf48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -339,12 +339,10 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand
override def run(sqlContext: SQLContext): Seq[Row] = {
// Since we need to return a Seq of rows, we will call getTables directly
// instead of calling tables in sqlContext.
- val catalog = sqlContext.sessionState.catalog
- val db = databaseName.getOrElse(catalog.getCurrentDatabase)
- val rows = catalog.listTables(db).map { t =>
- val isTemp = t.database.isEmpty
- Row(t.table, isTemp)
+ val rows = sqlContext.sessionState.catalog.getTables(databaseName).map {
+ case (tableName, isTemporary) => Row(tableName, isTemporary)
}
+
rows
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 24923bbb10..9e8e0352db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -93,21 +93,15 @@ case class CreateTempTableUsing(
provider: String,
options: Map[String, String]) extends RunnableCommand {
- if (tableIdent.database.isDefined) {
- throw new AnalysisException(
- s"Temporary table '$tableIdent' should not have specified a database")
- }
-
def run(sqlContext: SQLContext): Seq[Row] = {
val dataSource = DataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
className = provider,
options = options)
- sqlContext.sessionState.catalog.createTempTable(
- tableIdent.table,
- Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
- ignoreIfExists = true)
+ sqlContext.sessionState.catalog.registerTable(
+ tableIdent,
+ Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan)
Seq.empty[Row]
}
@@ -121,11 +115,6 @@ case class CreateTempTableUsingAsSelect(
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
- if (tableIdent.database.isDefined) {
- throw new AnalysisException(
- s"Temporary table '$tableIdent' should not have specified a database")
- }
-
override def run(sqlContext: SQLContext): Seq[Row] = {
val df = Dataset.ofRows(sqlContext, query)
val dataSource = DataSource(
@@ -135,10 +124,9 @@ case class CreateTempTableUsingAsSelect(
bucketSpec = None,
options = options)
val result = dataSource.write(mode, df)
- sqlContext.sessionState.catalog.createTempTable(
- tableIdent.table,
- Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan,
- ignoreIfExists = true)
+ sqlContext.sessionState.catalog.registerTable(
+ tableIdent,
+ Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan)
Seq.empty[Row]
}
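
With the constructor-time database check removed here, a database-qualified temporary table name is rejected later during parsing/analysis instead (the SQLQuerySuite change below expects the message "Specifying database name or other qualifiers are not allowed"). A usage sketch of the resulting behavior, assuming a parquet file at a hypothetical path /tmp/p:

    // Rejected: a temporary table name may not carry a database qualifier.
    // sqlContext.sql(
    //   "CREATE TEMPORARY TABLE db.t USING parquet OPTIONS (path '/tmp/p')")

    // Accepted: backticks make `db.t` a single table name containing a dot.
    sqlContext.sql(
      """
        |CREATE TEMPORARY TABLE `db.t`
        |USING parquet
        |OPTIONS (path '/tmp/p')
      """.stripMargin)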
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 28ac4583e9..63f0e4f8c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -19,12 +19,10 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation}
/**
@@ -101,9 +99,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
/**
* A rule to do various checks before inserting into or writing to a data source table.
*/
-private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
- extends (LogicalPlan => Unit) {
-
+private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) }
def apply(plan: LogicalPlan): Unit = {
@@ -143,7 +139,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
}
PartitioningUtils.validatePartitionColumnDataTypes(
- r.schema, part.keySet.toSeq, conf.caseSensitiveAnalysis)
+ r.schema, part.keySet.toSeq, catalog.conf.caseSensitiveAnalysis)
// Get all input data source relations of the query.
val srcRelations = query.collect {
@@ -194,7 +190,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
}
PartitioningUtils.validatePartitionColumnDataTypes(
- c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis)
+ c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis)
for {
spec <- c.bucketSpec
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index e5f02caabc..e6be0ab3bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.internal
import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog}
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -46,7 +45,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Internal catalog for managing table and database states.
*/
- lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
+ lazy val catalog: Catalog = new SimpleCatalog(conf)
/**
* Internal catalog for managing functions registered by the user.
@@ -69,7 +68,7 @@ private[sql] class SessionState(ctx: SQLContext) {
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
- override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog))
+ override val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog))
}
}
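
A minimal sketch of the restored wiring, assuming a conf: SQLConf in scope as in SessionState (illustration only, not part of the commit):

    import org.apache.spark.sql.catalyst.analysis.{Catalog, SimpleCatalog}
    import org.apache.spark.sql.execution.datasources

    lazy val catalog: Catalog = new SimpleCatalog(conf)
    // PreWriteCheck now takes only the catalog and reads
    // caseSensitiveAnalysis through catalog.conf, so no separate
    // SQLConf argument is threaded through the check rule.
    val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog))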
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index bb54c525cb..2820e4fa23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -33,8 +33,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
}
after {
- sqlContext.sessionState.catalog.dropTable(
- TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
+ sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
}
test("get all tables") {
@@ -46,22 +45,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- sqlContext.sessionState.catalog.dropTable(
- TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
+ sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
- test("getting all tables with a database name has no impact on returned table names") {
+ test("getting all Tables with a database name has no impact on returned table names") {
checkAnswer(
- sqlContext.tables("default").filter("tableName = 'ListTablesSuiteTable'"),
+ sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
checkAnswer(
- sql("show TABLES in default").filter("tableName = 'ListTablesSuiteTable'"),
+ sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- sqlContext.sessionState.catalog.dropTable(
- TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
+ sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index 2f62ad4850..2ad92b52c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
-class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
+class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
object DummyRule extends Rule[LogicalPlan] {
def apply(p: LogicalPlan): LogicalPlan = p
@@ -78,11 +78,4 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
sqlContext.experimental.extraOptimizations = Seq(DummyRule)
assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule))
}
-
- test("SQLContext can access `spark.sql.*` configs") {
- sc.conf.set("spark.sql.with.or.without.you", "my love")
- val sqlContext = new SQLContext(sc)
- assert(sqlContext.getConf("spark.sql.with.or.without.you") == "my love")
- }
-
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4f36b1b42a..eb486a135f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1397,16 +1397,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-4699 case sensitivity SQL query") {
- val orig = sqlContext.getConf(SQLConf.CASE_SENSITIVE)
- try {
- sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
- val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
- val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
- rdd.toDF().registerTempTable("testTable1")
- checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
- } finally {
- sqlContext.setConf(SQLConf.CASE_SENSITIVE, orig)
- }
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
+ val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
+ val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
+ rdd.toDF().registerTempTable("testTable1")
+ checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, true)
}
test("SPARK-6145: ORDER BY test for nested fields") {
@@ -1680,8 +1676,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
.format("parquet")
.save(path)
- // We don't support creating a temporary table while specifying a database
- intercept[AnalysisException] {
+ val message = intercept[AnalysisException] {
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE db.t
@@ -1691,8 +1686,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|)
""".stripMargin)
}.getMessage
+ assert(message.contains("Specifying database name or other qualifiers are not allowed"))
- // If you use backticks to quote the name then it's OK.
+ // If you use backticks to quote the name of a temporary table having dot in it.
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE `db.t`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 2f806ebba6..f8166c7ddc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -51,8 +51,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
- sqlContext.sessionState.catalog.dropTable(
- TableIdentifier("tmp"), ignoreIfNotExists = true)
+ sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
}
test("overwriting") {
@@ -62,8 +61,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
}
- sqlContext.sessionState.catalog.dropTable(
- TableIdentifier("tmp"), ignoreIfNotExists = true)
+ sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
}
test("self-join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 80a85a6615..d48358566e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -189,8 +189,8 @@ private[sql] trait SQLTestUtils
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
- sqlContext.sessionState.catalog.setCurrentDatabase(db)
- try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default")
+ sqlContext.sql(s"USE $db")
+ try f finally sqlContext.sql(s"USE default")
}
/**
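
After this change the helper switches databases through SQL rather than by mutating the catalog directly. A usage sketch, assuming a suite that mixes in SQLTestUtils and a pre-created database named mydb (hypothetical):

    // Runs the body with `mydb` active, then restores `default`,
    // even if the body throws.
    activateDatabase("mydb") {
      sqlContext.sql("CREATE TABLE t (key INT, value STRING)")
    }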