author     Reynold Xin <rxin@databricks.com>    2016-03-15 10:12:32 -0700
committer  Reynold Xin <rxin@databricks.com>    2016-03-15 10:12:32 -0700
commit     5e6f2f45639f727cba43967858eb95b865ed81fa (patch)
tree       07d6c86cbe877110adb5f88500e43fb64218519a /sql/core
parent     48978abfa4d8f2cf79a4b053cc8bc7254cc2d61b (diff)
[SPARK-13893][SQL] Remove SQLContext.catalog/analyzer (internal method)
## What changes were proposed in this pull request?
Internal code can go through SessionState.catalog and SessionState.analyzer instead of the removed SQLContext shortcuts. This brings two small benefits:
1. Reduces internal dependency on SQLContext.
2. Removes two methods that are exposed as public to Java callers (Java does not respect Scala's package-private visibility).
More importantly, according to the design in SPARK-13485, the `catalog` name needs to be claimed for user-facing public functions, rather than being taken by an internal field.
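For illustration only, here is the access pattern internal code switches to. This is a minimal sketch with a hypothetical object name, not part of the patch; since `sessionState` is package-private, it only compiles for code inside `org.apache.spark.sql`, which is exactly the internal code this patch touches:

```scala
package org.apache.spark.sql

// Sketch of the internal access path after this patch (hypothetical object
// name). Call sites stop using the removed SQLContext.catalog and
// SQLContext.analyzer shortcuts and reach the same objects via SessionState.
private[sql] object SessionStateAccessSketch {

  // Before: sqlContext.catalog.getTables(None)
  def tableNames(sqlContext: SQLContext): Array[String] = {
    sqlContext.sessionState.catalog.getTables(None).map {
      case (tableName, _) => tableName
    }.toArray
  }

  // Before: df.sqlContext.analyzer.resolver
  def columnEquals(df: DataFrame): (String, String) => Boolean = {
    df.sqlContext.sessionState.analyzer.resolver
  }
}
```

The diff below applies exactly this mechanical substitution across sql/core.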
## How was this patch tested?
Existing unit/integration test code.
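No new tests are added. As a sketch of the kind of existing coverage this relies on, here is a test modeled on the ListTablesSuite change in the diff below; the suite and test names are hypothetical, and only the catalog access path differs from the pre-patch code:

```scala
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.test.SharedSQLContext

// Modeled on the updated ListTablesSuite: the assertion is unchanged, only the
// cleanup call now goes through sessionState.catalog instead of the removed
// SQLContext.catalog.
class SessionStateCatalogSketchSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  before {
    (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
      .registerTempTable("ListTablesSuiteTable")
  }

  test("unregistering a temp table removes it from tables()") {
    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
    assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }
}
```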
Author: Reynold Xin <rxin@databricks.com>
Closes #11716 from rxin/SPARK-13893.
Diffstat (limited to 'sql/core')
11 files changed, 40 insertions, 38 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index f7be5f6b37..33588ef72f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -155,7 +155,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def fill(value: Double, cols: Seq[String]): DataFrame = {
-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       // Only fill if the column is part of the cols list.
       if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
@@ -182,7 +182,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def fill(value: String, cols: Seq[String]): DataFrame = {
-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       // Only fill if the column is part of the cols list.
       if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
@@ -353,7 +353,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       case _: String => StringType
     }

-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
       if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
@@ -382,7 +382,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       }
     }

-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
         v match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 76b8d71ac9..57c978bec8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,7 +394,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    */
   def table(tableName: String): DataFrame = {
     Dataset.newDataFrame(sqlContext,
-      sqlContext.catalog.lookupRelation(
+      sqlContext.sessionState.catalog.lookupRelation(
         sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index de87f4d7c2..9951f0fabf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -323,7 +323,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    */
   private def normalize(columnName: String, columnType: String): String = {
     val validColumnNames = df.logicalPlan.output.map(_.name)
-    validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName))
+    validColumnNames.find(df.sqlContext.sessionState.analyzer.resolver(_, columnName))
       .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
         s"existing columns (${validColumnNames.mkString(", ")})"))
   }
@@ -358,7 +358,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }

   private def saveAsTable(tableIdent: TableIdentifier): Unit = {
-    val tableExists = df.sqlContext.catalog.tableExists(tableIdent)
+    val tableExists = df.sqlContext.sessionState.catalog.tableExists(tableIdent)

     (tableExists, mode) match {
       case (true, SaveMode.Ignore) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ef239a1e2f..f7ef0de21c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -166,15 +166,16 @@ class Dataset[T] private[sql](
   private implicit def classTag = unresolvedTEncoder.clsTag

   protected[sql] def resolve(colName: String): NamedExpression = {
-    queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
-      throw new AnalysisException(
-        s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
-    }
+    queryExecution.analyzed.resolveQuoted(colName, sqlContext.sessionState.analyzer.resolver)
+      .getOrElse {
+        throw new AnalysisException(
+          s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+      }
   }

   protected[sql] def numericColumns: Seq[Expression] = {
     schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get
+      queryExecution.analyzed.resolveQuoted(n.name, sqlContext.sessionState.analyzer.resolver).get
     }
   }

@@ -1400,7 +1401,7 @@ class Dataset[T] private[sql](
    * @since 1.3.0
    */
   def withColumn(colName: String, col: Column): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldReplace = output.exists(f => resolver(f.name, colName))
     if (shouldReplace) {
@@ -1421,7 +1422,7 @@ class Dataset[T] private[sql](
    * Returns a new [[DataFrame]] by adding a column with metadata.
    */
   private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldReplace = output.exists(f => resolver(f.name, colName))
     if (shouldReplace) {
@@ -1445,7 +1446,7 @@ class Dataset[T] private[sql](
    * @since 1.3.0
    */
   def withColumnRenamed(existingName: String, newName: String): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldRename = output.exists(f => resolver(f.name, existingName))
     if (shouldRename) {
@@ -1480,7 +1481,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def drop(colNames: String*): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val remainingCols =
       schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
     if (remainingCols.size == this.schema.size) {
@@ -1501,7 +1502,8 @@ class Dataset[T] private[sql](
   def drop(col: Column): DataFrame = {
     val expression = col match {
       case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(u.name, sqlContext.analyzer.resolver).getOrElse(u)
+        queryExecution.analyzed.resolveQuoted(
+          u.name, sqlContext.sessionState.analyzer.resolver).getOrElse(u)
       case Column(expr: Expression) => expr
     }
     val attrs = this.logicalPlan.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 0f5d1c8cab..177d78c4c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -120,8 +120,6 @@ class SQLContext private[sql](
   @transient
   protected[sql] lazy val sessionState: SessionState = new SessionState(self)
   protected[sql] def conf: SQLConf = sessionState.conf
-  protected[sql] def catalog: Catalog = sessionState.catalog
-  protected[sql] def analyzer: Analyzer = sessionState.analyzer

   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -699,7 +697,8 @@ class SQLContext private[sql](
    * only during the lifetime of this instance of SQLContext.
    */
   private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
-    catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+    sessionState.catalog.registerTable(
+      sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
   }

   /**
@@ -712,7 +711,7 @@ class SQLContext private[sql](
    */
   def dropTempTable(tableName: String): Unit = {
     cacheManager.tryUncacheQuery(table(tableName))
-    catalog.unregisterTable(TableIdentifier(tableName))
+    sessionState.catalog.unregisterTable(TableIdentifier(tableName))
   }

   /**
@@ -797,7 +796,7 @@ class SQLContext private[sql](
   }

   private def table(tableIdent: TableIdentifier): DataFrame = {
-    Dataset.newDataFrame(this, catalog.lookupRelation(tableIdent))
+    Dataset.newDataFrame(this, sessionState.catalog.lookupRelation(tableIdent))
   }

   /**
@@ -839,7 +838,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
   def tableNames(): Array[String] = {
-    catalog.getTables(None).map {
+    sessionState.catalog.getTables(None).map {
       case (tableName, _) => tableName
     }.toArray
   }
@@ -851,7 +850,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
   def tableNames(databaseName: String): Array[String] = {
-    catalog.getTables(Some(databaseName)).map {
+    sessionState.catalog.getTables(Some(databaseName)).map {
       case (tableName, _) => tableName
     }.toArray
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5b4254f741..912b84abc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -31,14 +31,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
  */
 class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

-  def assertAnalyzed(): Unit = try sqlContext.analyzer.checkAnalysis(analyzed) catch {
+  def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
     case e: AnalysisException =>
       val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
       ae.setStackTrace(e.getStackTrace)
       throw ae
   }

-  lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
+  lazy val analyzed: LogicalPlan = sqlContext.sessionState.analyzer.execute(logical)

   lazy val withCachedData: LogicalPlan = {
     assertAnalyzed()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index e711797c1b..44b07e4613 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -330,7 +330,7 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma
   override def run(sqlContext: SQLContext): Seq[Row] = {
     // Since we need to return a Seq of rows, we will call getTables directly
     // instead of calling tables in sqlContext.
-    val rows = sqlContext.catalog.getTables(databaseName).map {
+    val rows = sqlContext.sessionState.catalog.getTables(databaseName).map {
       case (tableName, isTemporary) => Row(tableName, isTemporary)
     }

@@ -417,7 +417,7 @@ case class DescribeFunction(
 case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {

   override def run(sqlContext: SQLContext): Seq[Row] = {
-    sqlContext.catalog.setCurrentDatabase(databaseName)
+    sqlContext.sessionState.catalog.setCurrentDatabase(databaseName)
     Seq.empty[Row]
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index ef95d5d289..84e98c0f9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -67,7 +67,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
       val filterSet = ExpressionSet(filters)

       val partitionColumns =
-        AttributeSet(l.resolve(files.partitionSchema, files.sqlContext.analyzer.resolver))
+        AttributeSet(
+          l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver))
       val partitionKeyFilters =
         ExpressionSet(filters.filter(_.references.subsetOf(partitionColumns)))
       logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 04e51735c4..7ca0e8859a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -99,7 +99,7 @@ case class CreateTempTableUsing(
       userSpecifiedSchema = userSpecifiedSchema,
       className = provider,
       options = options)
-    sqlContext.catalog.registerTable(
+    sqlContext.sessionState.catalog.registerTable(
       tableIdent,
       Dataset.newDataFrame(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan)

@@ -124,7 +124,7 @@ case class CreateTempTableUsingAsSelect(
       bucketSpec = None,
       options = options)
     val result = dataSource.write(mode, df)
-    sqlContext.catalog.registerTable(
+    sqlContext.sessionState.catalog.registerTable(
       tableIdent,
       Dataset.newDataFrame(sqlContext, LogicalRelation(result)).logicalPlan)

@@ -137,11 +137,11 @@ case class RefreshTable(tableIdent: TableIdentifier)

   override def run(sqlContext: SQLContext): Seq[Row] = {
     // Refresh the given table's metadata first.
-    sqlContext.catalog.refreshTable(tableIdent)
+    sqlContext.sessionState.catalog.refreshTable(tableIdent)

     // If this table is cached as a InMemoryColumnarRelation, drop the original
     // cached version and make the new version cached lazily.
-    val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent)
+    val logicalPlan = sqlContext.sessionState.catalog.lookupRelation(tableIdent)
     // Use lookupCachedData directly since RefreshTable also takes databaseName.
     val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty
     if (isCached) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 3d7c576965..2820e4fa23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -33,7 +33,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
   }

   after {
-    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
   }

   test("get all tables") {
@@ -45,7 +45,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
       sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))

-    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
     assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
   }

@@ -58,7 +58,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
       sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))

-    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
     assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index acfc1a518a..fb99b0c7e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -51,7 +51,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       sql("INSERT INTO TABLE t SELECT * FROM tmp")
       checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(TableIdentifier("tmp"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
   }

   test("overwriting") {
@@ -61,7 +61,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
       checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(TableIdentifier("tmp"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
   }

   test("self-join") {