Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  73
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala  24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala  15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala  4
10 files changed, 105 insertions(+), 73 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 853a74c827..e413e77bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,13 +25,14 @@ import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
@@ -65,13 +66,14 @@ class SQLContext private[sql](
@transient val sparkContext: SparkContext,
@transient protected[sql] val cacheManager: CacheManager,
@transient private[sql] val listener: SQLListener,
- val isRootContext: Boolean)
+ val isRootContext: Boolean,
+ @transient private[sql] val externalCatalog: ExternalCatalog)
extends Logging with Serializable {
self =>
- def this(sparkContext: SparkContext) = {
- this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true)
+ def this(sc: SparkContext) = {
+ this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), true, new InMemoryCatalog)
}
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
@@ -109,7 +111,8 @@ class SQLContext private[sql](
sparkContext = sparkContext,
cacheManager = cacheManager,
listener = listener,
- isRootContext = false)
+ isRootContext = false,
+ externalCatalog = externalCatalog)
}
/**
@@ -186,6 +189,12 @@ class SQLContext private[sql](
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
+ // Extract `spark.sql.*` entries and put them in our SQLConf.
+ // Subclasses may additionally set these entries in other confs.
+ SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) =>
+ setConf(k, v)
+ }
+
protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
@@ -199,30 +208,6 @@ class SQLContext private[sql](
sparkContext.addJar(path)
}
- {
- // We extract spark sql settings from SparkContext's conf and put them to
- // Spark SQL's conf.
- // First, we populate the SQLConf (conf). So, we can make sure that other values using
- // those settings in their construction can get the correct settings.
- // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version
- // and spark.sql.hive.metastore.jars to get correctly constructed.
- val properties = new Properties
- sparkContext.getConf.getAll.foreach {
- case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value)
- case _ =>
- }
- // We directly put those settings to conf to avoid of calling setConf, which may have
- // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive
- // get constructed. If we call setConf directly, the constructed metadataHive may have
- // wrong settings, or the construction may fail.
- conf.setConf(properties)
- // After we have populated SQLConf, we call setConf to populate other confs in the subclass
- // (e.g. hiveconf in HiveContext).
- properties.asScala.foreach {
- case (key, value) => setConf(key, value)
- }
- }
-
/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -683,8 +668,10 @@ class SQLContext private[sql](
* only during the lifetime of this instance of SQLContext.
*/
private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
- sessionState.catalog.registerTable(
- sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+ sessionState.catalog.createTempTable(
+ sessionState.sqlParser.parseTableIdentifier(tableName).table,
+ df.logicalPlan,
+ ignoreIfExists = true)
}
/**
@@ -697,7 +684,7 @@ class SQLContext private[sql](
*/
def dropTempTable(tableName: String): Unit = {
cacheManager.tryUncacheQuery(table(tableName))
- sessionState.catalog.unregisterTable(TableIdentifier(tableName))
+ sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true)
}
/**
@@ -824,9 +811,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tableNames(): Array[String] = {
- sessionState.catalog.getTables(None).map {
- case (tableName, _) => tableName
- }.toArray
+ tableNames(sessionState.catalog.getCurrentDatabase)
}
/**
@@ -836,9 +821,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tableNames(databaseName: String): Array[String] = {
- sessionState.catalog.getTables(Some(databaseName)).map {
- case (tableName, _) => tableName
- }.toArray
+ sessionState.catalog.listTables(databaseName).map(_.table).toArray
}
@transient
@@ -1025,4 +1008,18 @@ object SQLContext {
}
sqlListener.get()
}
+
+ /**
+ * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
+ */
+ private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = {
+ val properties = new Properties
+ sparkConf.getAll.foreach { case (key, value) =>
+ if (key.startsWith("spark.sql")) {
+ properties.setProperty(key, value)
+ }
+ }
+ properties
+ }
+
}
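The SQLContext.scala change above replaces the old constructor-time block with a small `getSQLProperties` helper plus a `setConf` loop. The following standalone sketch reproduces just that filtering step outside of Spark, assuming a plain `Seq[(String, String)]` in place of `SparkConf.getAll`; `SqlPropsSketch` and `extractSqlProperties` are hypothetical names used only for illustration.

    import java.util.Properties

    // Standalone equivalent of the filter performed by getSQLProperties above:
    // only keys starting with "spark.sql" are copied into a Properties object,
    // which the SQLContext constructor then feeds through setConf.
    object SqlPropsSketch {
      def extractSqlProperties(rawConf: Seq[(String, String)]): Properties = {
        val properties = new Properties
        rawConf.foreach { case (key, value) =>
          if (key.startsWith("spark.sql")) {
            properties.setProperty(key, value)
          }
        }
        properties
      }

      def main(args: Array[String]): Unit = {
        val conf = Seq(
          "spark.master" -> "local[*]",
          "spark.sql.shuffle.partitions" -> "4",
          "spark.sql.caseSensitive" -> "false")
        val props = extractSqlProperties(conf)
        // Only the two spark.sql.* entries survive the filter.
        assert(props.size == 2)
        println(props)
      }
    }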
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 59c3ffcf48..964f0a7a7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -339,10 +339,12 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma
override def run(sqlContext: SQLContext): Seq[Row] = {
// Since we need to return a Seq of rows, we will call getTables directly
// instead of calling tables in sqlContext.
- val rows = sqlContext.sessionState.catalog.getTables(databaseName).map {
- case (tableName, isTemporary) => Row(tableName, isTemporary)
+ val catalog = sqlContext.sessionState.catalog
+ val db = databaseName.getOrElse(catalog.getCurrentDatabase)
+ val rows = catalog.listTables(db).map { t =>
+ val isTemp = t.database.isEmpty
+ Row(t.table, isTemp)
}
-
rows
}
}
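The rewritten ShowTablesCommand above derives the isTemporary flag from whether the listed identifier carries a database component. A minimal standalone sketch of that convention, with `TempAwareTable` as a hypothetical stand-in for Spark's `TableIdentifier`:

    // TempAwareTable stands in for Spark's TableIdentifier: a temporary table
    // is listed without a database component, so isTemp is database.isEmpty.
    case class TempAwareTable(table: String, database: Option[String])

    object ShowTablesSketch {
      // Mirrors the row-building step in ShowTablesCommand.run above,
      // producing (tableName, isTemporary) pairs instead of Rows.
      def toRows(tables: Seq[TempAwareTable]): Seq[(String, Boolean)] =
        tables.map { t =>
          val isTemp = t.database.isEmpty
          (t.table, isTemp)
        }

      def main(args: Array[String]): Unit = {
        val listed = Seq(
          TempAwareTable("people", Some("default")),   // persistent table
          TempAwareTable("tmp_view", None))            // temporary table
        toRows(listed).foreach(println)                // (people,false), (tmp_view,true)
      }
    }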
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 9e8e0352db..24923bbb10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -93,15 +93,21 @@ case class CreateTempTableUsing(
provider: String,
options: Map[String, String]) extends RunnableCommand {
+ if (tableIdent.database.isDefined) {
+ throw new AnalysisException(
+ s"Temporary table '$tableIdent' should not have specified a database")
+ }
+
def run(sqlContext: SQLContext): Seq[Row] = {
val dataSource = DataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
className = provider,
options = options)
- sqlContext.sessionState.catalog.registerTable(
- tableIdent,
- Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan)
+ sqlContext.sessionState.catalog.createTempTable(
+ tableIdent.table,
+ Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
+ ignoreIfExists = true)
Seq.empty[Row]
}
@@ -115,6 +121,11 @@ case class CreateTempTableUsingAsSelect(
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
+ if (tableIdent.database.isDefined) {
+ throw new AnalysisException(
+ s"Temporary table '$tableIdent' should not have specified a database")
+ }
+
override def run(sqlContext: SQLContext): Seq[Row] = {
val df = Dataset.ofRows(sqlContext, query)
val dataSource = DataSource(
@@ -124,9 +135,10 @@ case class CreateTempTableUsingAsSelect(
bucketSpec = None,
options = options)
val result = dataSource.write(mode, df)
- sqlContext.sessionState.catalog.registerTable(
- tableIdent,
- Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan)
+ sqlContext.sessionState.catalog.createTempTable(
+ tableIdent.table,
+ Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan,
+ ignoreIfExists = true)
Seq.empty[Row]
}
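Both CreateTempTableUsing and CreateTempTableUsingAsSelect now reject a database-qualified name at construction time. A self-contained sketch of that guard, with `TableRef` and `AnalysisError` as hypothetical stand-ins for `TableIdentifier` and `AnalysisException`:

    // Hypothetical stand-ins for TableIdentifier and AnalysisException.
    case class TableRef(table: String, database: Option[String]) {
      override def toString: String = database.map(d => s"$d.$table").getOrElse(table)
    }

    final class AnalysisError(msg: String) extends Exception(msg)

    object TempTableGuardSketch {
      // Reject a database-qualified name up front, before any resolution work,
      // mirroring the constructor-time check added to both commands above.
      def requireNoDatabase(tableIdent: TableRef): Unit = {
        if (tableIdent.database.isDefined) {
          throw new AnalysisError(
            s"Temporary table '$tableIdent' should not have specified a database")
        }
      }

      def main(args: Array[String]): Unit = {
        requireNoDatabase(TableRef("t", None))                // accepted
        try requireNoDatabase(TableRef("t", Some("db")))      // rejected
        catch { case e: AnalysisError => println(e.getMessage) }
      }
    }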
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 63f0e4f8c9..28ac4583e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation}
/**
@@ -99,7 +101,9 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
/**
* A rule to do various checks before inserting into or writing to a data source table.
*/
-private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
+private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
+ extends (LogicalPlan => Unit) {
+
def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) }
def apply(plan: LogicalPlan): Unit = {
@@ -139,7 +143,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
PartitioningUtils.validatePartitionColumnDataTypes(
- r.schema, part.keySet.toSeq, catalog.conf.caseSensitiveAnalysis)
+ r.schema, part.keySet.toSeq, conf.caseSensitiveAnalysis)
// Get all input data source relations of the query.
val srcRelations = query.collect {
@@ -190,7 +194,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
PartitioningUtils.validatePartitionColumnDataTypes(
- c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis)
+ c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis)
for {
spec <- c.bucketSpec
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index e6be0ab3bc..e5f02caabc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.internal
import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -45,7 +46,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Internal catalog for managing table and database states.
*/
- lazy val catalog: Catalog = new SimpleCatalog(conf)
+ lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
/**
* Internal catalog for managing functions registered by the user.
@@ -68,7 +69,7 @@ private[sql] class SessionState(ctx: SQLContext) {
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
- override val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog))
+ override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 2820e4fa23..bb54c525cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -33,7 +33,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
}
after {
- sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+ sqlContext.sessionState.catalog.dropTable(
+ TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
}
test("get all tables") {
@@ -45,20 +46,22 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+ sqlContext.sessionState.catalog.dropTable(
+ TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
- test("getting all Tables with a database name has no impact on returned table names") {
+ test("getting all tables with a database name has no impact on returned table names") {
checkAnswer(
- sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+ sqlContext.tables("default").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
checkAnswer(
- sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
+ sql("show TABLES in default").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
- sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+ sqlContext.sessionState.catalog.dropTable(
+ TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index 2ad92b52c4..2f62ad4850 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
-class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
+class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
object DummyRule extends Rule[LogicalPlan] {
def apply(p: LogicalPlan): LogicalPlan = p
@@ -78,4 +78,11 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
sqlContext.experimental.extraOptimizations = Seq(DummyRule)
assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule))
}
+
+ test("SQLContext can access `spark.sql.*` configs") {
+ sc.conf.set("spark.sql.with.or.without.you", "my love")
+ val sqlContext = new SQLContext(sc)
+ assert(sqlContext.getConf("spark.sql.with.or.without.you") == "my love")
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 077e579931..c958eac266 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1476,12 +1476,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-4699 case sensitivity SQL query") {
- sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
- val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
- val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
- rdd.toDF().registerTempTable("testTable1")
- checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
- sqlContext.setConf(SQLConf.CASE_SENSITIVE, true)
+ val orig = sqlContext.getConf(SQLConf.CASE_SENSITIVE)
+ try {
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
+ val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
+ val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
+ rdd.toDF().registerTempTable("testTable1")
+ checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
+ } finally {
+ sqlContext.setConf(SQLConf.CASE_SENSITIVE, orig)
+ }
}
test("SPARK-6145: ORDER BY test for nested fields") {
@@ -1755,7 +1759,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
.format("parquet")
.save(path)
- val message = intercept[AnalysisException] {
+ // We don't support creating a temporary table while specifying a database
+ intercept[AnalysisException] {
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE db.t
@@ -1765,9 +1770,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|)
""".stripMargin)
}.getMessage
- assert(message.contains("Specifying database name or other qualifiers are not allowed"))
- // If you use backticks to quote the name of a temporary table having dot in it.
+ // If you use backticks to quote the name then it's OK.
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE `db.t`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index f8166c7ddc..2f806ebba6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -51,7 +51,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
- sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
+ sqlContext.sessionState.catalog.dropTable(
+ TableIdentifier("tmp"), ignoreIfNotExists = true)
}
test("overwriting") {
@@ -61,7 +62,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
}
- sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
+ sqlContext.sessionState.catalog.dropTable(
+ TableIdentifier("tmp"), ignoreIfNotExists = true)
}
test("self-join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d48358566e..80a85a6615 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -189,8 +189,8 @@ private[sql] trait SQLTestUtils
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
- sqlContext.sql(s"USE $db")
- try f finally sqlContext.sql(s"USE default")
+ sqlContext.sessionState.catalog.setCurrentDatabase(db)
+ try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default")
}
/**
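The test-side changes above (the SPARK-4699 fix and activateDatabase) both follow a save-mutate-restore pattern around shared session state. A minimal sketch of that pattern, with `currentDatabase` as a hypothetical stand-in for the catalog's current-database setting:

    object WithSettingSketch {
      // Hypothetical stand-in for the catalog's mutable current-database state.
      private var currentDatabase: String = "default"

      // Set the database for the duration of `body`, then restore the previous
      // value even if `body` throws, so one failing test cannot leak its
      // setting into later tests.
      def withDatabase[T](db: String)(body: => T): T = {
        val saved = currentDatabase
        currentDatabase = db
        try body finally currentDatabase = saved
      }

      def main(args: Array[String]): Unit = {
        withDatabase("analytics") {
          println(s"inside block: $currentDatabase")   // analytics
        }
        println(s"after block: $currentDatabase")      // default restored
      }
    }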