Diffstat (limited to 'sql/core/src')
10 files changed, 105 insertions, 73 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 853a74c827..e413e77bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,13 +25,14 @@ import scala.collection.JavaConverters._
 import scala.collection.immutable
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.{SparkConf, SparkContext, SparkException}
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
@@ -65,13 +66,14 @@ class SQLContext private[sql](
     @transient val sparkContext: SparkContext,
     @transient protected[sql] val cacheManager: CacheManager,
     @transient private[sql] val listener: SQLListener,
-    val isRootContext: Boolean)
+    val isRootContext: Boolean,
+    @transient private[sql] val externalCatalog: ExternalCatalog)
   extends Logging with Serializable {
 
   self =>
 
-  def this(sparkContext: SparkContext) = {
-    this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true)
+  def this(sc: SparkContext) = {
+    this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), true, new InMemoryCatalog)
   }
 
   def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
@@ -109,7 +111,8 @@ class SQLContext private[sql](
       sparkContext = sparkContext,
       cacheManager = cacheManager,
       listener = listener,
-      isRootContext = false)
+      isRootContext = false,
+      externalCatalog = externalCatalog)
   }
 
   /**
@@ -186,6 +189,12 @@ class SQLContext private[sql](
    */
   def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
 
+  // Extract `spark.sql.*` entries and put it in our SQLConf.
+  // Subclasses may additionally set these entries in other confs.
+  SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) =>
+    setConf(k, v)
+  }
+
   protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
 
   protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
@@ -199,30 +208,6 @@ class SQLContext private[sql](
     sparkContext.addJar(path)
   }
 
-  {
-    // We extract spark sql settings from SparkContext's conf and put them to
-    // Spark SQL's conf.
-    // First, we populate the SQLConf (conf). So, we can make sure that other values using
-    // those settings in their construction can get the correct settings.
-    // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version
-    // and spark.sql.hive.metastore.jars to get correctly constructed.
-    val properties = new Properties
-    sparkContext.getConf.getAll.foreach {
-      case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value)
-      case _ =>
-    }
-    // We directly put those settings to conf to avoid of calling setConf, which may have
-    // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive
-    // get constructed. If we call setConf directly, the constructed metadataHive may have
-    // wrong settings, or the construction may fail.
-    conf.setConf(properties)
-    // After we have populated SQLConf, we call setConf to populate other confs in the subclass
-    // (e.g. hiveconf in HiveContext).
-    properties.asScala.foreach {
-      case (key, value) => setConf(key, value)
-    }
-  }
-
   /**
    * :: Experimental ::
    * A collection of methods that are considered experimental, but can be used to hook into
@@ -683,8 +668,10 @@ class SQLContext private[sql](
    * only during the lifetime of this instance of SQLContext.
    */
   private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
-    sessionState.catalog.registerTable(
-      sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+    sessionState.catalog.createTempTable(
+      sessionState.sqlParser.parseTableIdentifier(tableName).table,
+      df.logicalPlan,
+      ignoreIfExists = true)
   }
 
   /**
@@ -697,7 +684,7 @@ class SQLContext private[sql](
    */
   def dropTempTable(tableName: String): Unit = {
     cacheManager.tryUncacheQuery(table(tableName))
-    sessionState.catalog.unregisterTable(TableIdentifier(tableName))
+    sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true)
   }
 
   /**
@@ -824,9 +811,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
  def tableNames(): Array[String] = {
-    sessionState.catalog.getTables(None).map {
-      case (tableName, _) => tableName
-    }.toArray
+    tableNames(sessionState.catalog.getCurrentDatabase)
   }
 
   /**
@@ -836,9 +821,7 @@ class SQLContext private[sql](
    * @since 1.3.0
   */
   def tableNames(databaseName: String): Array[String] = {
-    sessionState.catalog.getTables(Some(databaseName)).map {
-      case (tableName, _) => tableName
-    }.toArray
+    sessionState.catalog.listTables(databaseName).map(_.table).toArray
   }
 
   @transient
@@ -1025,4 +1008,18 @@ object SQLContext {
     }
     sqlListener.get()
   }
+
+  /**
+   * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
+   */
+  private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = {
+    val properties = new Properties
+    sparkConf.getAll.foreach { case (key, value) =>
+      if (key.startsWith("spark.sql")) {
+        properties.setProperty(key, value)
+      }
+    }
+    properties
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 59c3ffcf48..964f0a7a7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -339,10 +339,12 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma
   override def run(sqlContext: SQLContext): Seq[Row] = {
     // Since we need to return a Seq of rows, we will call getTables directly
     // instead of calling tables in sqlContext.
-    val rows = sqlContext.sessionState.catalog.getTables(databaseName).map {
-      case (tableName, isTemporary) => Row(tableName, isTemporary)
+    val catalog = sqlContext.sessionState.catalog
+    val db = databaseName.getOrElse(catalog.getCurrentDatabase)
+    val rows = catalog.listTables(db).map { t =>
+      val isTemp = t.database.isEmpty
+      Row(t.table, isTemp)
     }
-
     rows
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 9e8e0352db..24923bbb10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -93,15 +93,21 @@ case class CreateTempTableUsing(
     provider: String,
     options: Map[String, String]) extends RunnableCommand {
 
+  if (tableIdent.database.isDefined) {
+    throw new AnalysisException(
+      s"Temporary table '$tableIdent' should not have specified a database")
+  }
+
   def run(sqlContext: SQLContext): Seq[Row] = {
     val dataSource = DataSource(
       sqlContext,
       userSpecifiedSchema = userSpecifiedSchema,
       className = provider,
       options = options)
-    sqlContext.sessionState.catalog.registerTable(
-      tableIdent,
-      Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan)
+    sqlContext.sessionState.catalog.createTempTable(
+      tableIdent.table,
+      Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
+      ignoreIfExists = true)
 
     Seq.empty[Row]
   }
@@ -115,6 +121,11 @@ case class CreateTempTableUsingAsSelect(
     options: Map[String, String],
     query: LogicalPlan) extends RunnableCommand {
 
+  if (tableIdent.database.isDefined) {
+    throw new AnalysisException(
+      s"Temporary table '$tableIdent' should not have specified a database")
+  }
+
   override def run(sqlContext: SQLContext): Seq[Row] = {
     val df = Dataset.ofRows(sqlContext, query)
     val dataSource = DataSource(
@@ -124,9 +135,10 @@ case class CreateTempTableUsingAsSelect(
       bucketSpec = None,
       options = options)
     val result = dataSource.write(mode, df)
-    sqlContext.sessionState.catalog.registerTable(
-      tableIdent,
-      Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan)
+    sqlContext.sessionState.catalog.createTempTable(
+      tableIdent.table,
+      Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan,
+      ignoreIfExists = true)
 
     Seq.empty[Row]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 63f0e4f8c9..28ac4583e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation}
 
 /**
@@ -99,7 +101,9 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
 /**
  * A rule to do various checks before inserting into or writing to a data source table.
  */
-private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
+private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
+  extends (LogicalPlan => Unit) {
+
   def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) }
 
   def apply(plan: LogicalPlan): Unit = {
@@ -139,7 +143,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
         }
 
         PartitioningUtils.validatePartitionColumnDataTypes(
-          r.schema, part.keySet.toSeq, catalog.conf.caseSensitiveAnalysis)
+          r.schema, part.keySet.toSeq, conf.caseSensitiveAnalysis)
 
         // Get all input data source relations of the query.
         val srcRelations = query.collect {
@@ -190,7 +194,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
         }
 
         PartitioningUtils.validatePartitionColumnDataTypes(
-          c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis)
+          c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis)
 
         for {
           spec <- c.bucketSpec
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index e6be0ab3bc..e5f02caabc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.internal
 
 import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
@@ -45,7 +46,7 @@ private[sql] class SessionState(ctx: SQLContext) {
   /**
    * Internal catalog for managing table and database states.
    */
-  lazy val catalog: Catalog = new SimpleCatalog(conf)
+  lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
 
   /**
    * Internal catalog for managing functions registered by the user.
@@ -68,7 +69,7 @@ private[sql] class SessionState(ctx: SQLContext) {
         DataSourceAnalysis ::
         (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
 
-      override val extendedCheckRules = Seq(datasources.PreWriteCheck(catalog))
+      override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 2820e4fa23..bb54c525cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -33,7 +33,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
   }
 
   after {
-    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.dropTable(
+      TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
   }
 
   test("get all tables") {
@@ -45,20 +46,22 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
       sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
-    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.dropTable(
+      TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
     assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
   }
 
-  test("getting all Tables with a database name has no impact on returned table names") {
+  test("getting all tables with a database name has no impact on returned table names") {
     checkAnswer(
-      sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+      sqlContext.tables("default").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
     checkAnswer(
-      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
+      sql("show TABLES in default").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
-    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.dropTable(
+      TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true)
     assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index 2ad92b52c4..2f62ad4850 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
 
-class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
+class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
 
   object DummyRule extends Rule[LogicalPlan] {
     def apply(p: LogicalPlan): LogicalPlan = p
@@ -78,4 +78,11 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
     sqlContext.experimental.extraOptimizations = Seq(DummyRule)
     assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule))
   }
+
+  test("SQLContext can access `spark.sql.*` configs") {
+    sc.conf.set("spark.sql.with.or.without.you", "my love")
+    val sqlContext = new SQLContext(sc)
+    assert(sqlContext.getConf("spark.sql.with.or.without.you") == "my love")
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 077e579931..c958eac266 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1476,12 +1476,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-4699 case sensitivity SQL query") {
-    sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
-    val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
-    val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
-    rdd.toDF().registerTempTable("testTable1")
-    checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
-    sqlContext.setConf(SQLConf.CASE_SENSITIVE, true)
+    val orig = sqlContext.getConf(SQLConf.CASE_SENSITIVE)
+    try {
+      sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
+      val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
+      val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
+      rdd.toDF().registerTempTable("testTable1")
+      checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
+    } finally {
+      sqlContext.setConf(SQLConf.CASE_SENSITIVE, orig)
+    }
   }
 
   test("SPARK-6145: ORDER BY test for nested fields") {
@@ -1755,7 +1759,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
        .format("parquet")
        .save(path)
 
-      val message = intercept[AnalysisException] {
+      // We don't support creating a temporary table while specifying a database
+      intercept[AnalysisException] {
        sqlContext.sql(
          s"""
            |CREATE TEMPORARY TABLE db.t
            |USING parquet
            |OPTIONS (
            |  path '$path'
            |)
          """.stripMargin)
      }.getMessage
-      assert(message.contains("Specifying database name or other qualifiers are not allowed"))
 
-      // If you use backticks to quote the name of a temporary table having dot in it.
+      // If you use backticks to quote the name then it's OK.
      sqlContext.sql(
        s"""
          |CREATE TEMPORARY TABLE `db.t`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index f8166c7ddc..2f806ebba6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -51,7 +51,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       sql("INSERT INTO TABLE t SELECT * FROM tmp")
       checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
     }
-    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
+    sqlContext.sessionState.catalog.dropTable(
+      TableIdentifier("tmp"), ignoreIfNotExists = true)
   }
 
   test("overwriting") {
@@ -61,7 +62,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
       checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
     }
-    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
+    sqlContext.sessionState.catalog.dropTable(
+      TableIdentifier("tmp"), ignoreIfNotExists = true)
   }
 
   test("self-join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d48358566e..80a85a6615 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -189,8 +189,8 @@ private[sql] trait SQLTestUtils
    * `f` returns.
    */
   protected def activateDatabase(db: String)(f: => Unit): Unit = {
-    sqlContext.sql(s"USE $db")
-    try f finally sqlContext.sql(s"USE default")
+    sqlContext.sessionState.catalog.setCurrentDatabase(db)
+    try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default")
   }
 
   /**
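Editor's note (not part of the commit): the SQLContext hunks above replace the old constructor-time block that copied `spark.sql.*` settings out of the SparkContext's conf with the new `SQLContext.getSQLProperties` helper plus a `setConf` loop, as exercised by the new SQLContextSuite test. A minimal sketch of the resulting behavior, assuming a local master and using `spark.sql.shuffle.partitions` only as an illustrative key:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object SQLConfPropagationSketch {
      def main(args: Array[String]): Unit = {
        // Any `spark.sql.*` entry on the SparkConf is copied into SQLConf when the
        // SQLContext is constructed (via SQLContext.getSQLProperties + setConf above).
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("sql-conf-propagation-sketch")
          .set("spark.sql.shuffle.partitions", "4")
        val sc = new SparkContext(conf)
        val sqlContext = new SQLContext(sc)
        println(sqlContext.getConf("spark.sql.shuffle.partitions"))  // expected: 4
        sc.stop()
      }
    }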
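Likewise, callers in this diff move from the old `Catalog.registerTable`/`unregisterTable` pair to `SessionCatalog.createTempTable`/`dropTable`, which take explicit ignore flags and key temporary tables by a bare table name. A rough sketch of that calling pattern, mirroring registerDataFrameAsTable and dropTempTable above; it compiles only inside Spark's own org.apache.spark.sql package (sessionState and df.logicalPlan are private[sql]), and the object and method names here are made up for illustration:

    package org.apache.spark.sql

    import org.apache.spark.sql.catalyst.TableIdentifier

    object TempTableCatalogSketch {
      // Mirrors the calls made in this commit; sqlContext, df and name are assumed to exist.
      def registerThenDrop(sqlContext: SQLContext, df: DataFrame, name: String): Unit = {
        val catalog = sqlContext.sessionState.catalog
        // createTempTable replaces registerTable: temp tables are keyed by a bare name
        // (no database), and ignoreIfExists = true keeps the call from failing when the
        // name is already registered.
        catalog.createTempTable(name, df.logicalPlan, ignoreIfExists = true)
        // dropTable replaces unregisterTable; ignoreIfNotExists = true makes the call a
        // no-op when the temp table does not exist.
        catalog.dropTable(TableIdentifier(name), ignoreIfNotExists = true)
      }
    }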