From 18c2c92580bdc27aa5129d9e7abda418a3633ea6 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Mon, 25 Apr 2016 20:54:31 -0700
Subject: [SPARK-14861][SQL] Replace internal usages of SQLContext with SparkSession

## What changes were proposed in this pull request?

In Spark 2.0, `SparkSession` is the new entry point. Internally we should stop using `SQLContext` everywhere, since it's no longer supposed to be the main user-facing API. In this patch I took care not to break any public APIs. The one place that's suspect is `o.a.s.ml.source.libsvm.DefaultSource`, but according to mengxr it's not supposed to be public, so it's OK to change the underlying `FileFormat` trait.

**Reviewers**: This is a big patch that may be difficult to review, but the changes are actually really straightforward. If you prefer, I can break it up into a few smaller patches, but that would delay progress on this issue a little.
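The change is almost entirely mechanical. Condensed from the hunks below, a representative before/after of an internal call site looks like this (illustrative only, not a new public API):

```scala
// Before: internal code threaded a SQLContext through planning and execution
val qe = sqlContext.executePlan(logicalPlan)
val resolver = df.sqlContext.sessionState.analyzer.resolver
Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation()))

// After: the same call sites take a SparkSession; SQLContext stays around only as a
// thin user-facing wrapper (Dataset.sqlContext now delegates to sparkSession.wrapped)
val qe = sparkSession.executePlan(logicalPlan)
val resolver = df.sparkSession.sessionState.analyzer.resolver
Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
```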
## How was this patch tested?

No change in functionality intended.

Author: Andrew Or

Closes #12625 from andrewor14/spark-session-refactor.
---
 .../org/apache/spark/sql/ContinuousQuery.scala | 4 +-
 .../apache/spark/sql/ContinuousQueryManager.scala | 12 +--
 .../apache/spark/sql/DataFrameNaFunctions.scala | 8 +-
 .../org/apache/spark/sql/DataFrameReader.scala | 38 +++----
 .../org/apache/spark/sql/DataFrameWriter.scala | 30 +++---
 .../main/scala/org/apache/spark/sql/Dataset.scala | 112 +++++++++++----------
 .../apache/spark/sql/KeyValueGroupedDataset.scala | 12 +--
 .../spark/sql/RelationalGroupedDataset.scala | 12 +--
 .../scala/org/apache/spark/sql/SQLContext.scala | 2 +-
 .../scala/org/apache/spark/sql/SparkSession.scala | 47 ++++-----
 .../apache/spark/sql/execution/CacheManager.scala | 8 +-
 .../apache/spark/sql/execution/ExistingRDD.scala | 15 +--
 .../spark/sql/execution/QueryExecution.scala | 42 ++++----
 .../apache/spark/sql/execution/SQLExecution.scala | 11 +-
 .../org/apache/spark/sql/execution/SparkPlan.scala | 4 +-
 .../spark/sql/execution/command/AnalyzeTable.scala | 8 +-
 .../sql/execution/command/HiveNativeCommand.scala | 6 +-
 .../spark/sql/execution/command/SetCommand.scala | 50 ++++-----
 .../apache/spark/sql/execution/command/cache.scala | 18 ++--
 .../spark/sql/execution/command/commands.scala | 12 +--
 .../execution/command/createDataSourceTables.scala | 38 +++----
 .../spark/sql/execution/command/databases.scala | 12 +--
 .../apache/spark/sql/execution/command/ddl.scala | 56 +++++------
 .../spark/sql/execution/command/functions.scala | 20 ++--
 .../spark/sql/execution/command/resources.scala | 10 +-
 .../spark/sql/execution/command/tables.scala | 36 +++----
 .../apache/spark/sql/execution/command/views.scala | 20 ++--
 .../sql/execution/datasources/DataSource.scala | 56 +++++------
 .../sql/execution/datasources/FileScanRDD.scala | 6 +-
 .../execution/datasources/FileSourceStrategy.scala | 16 +--
 .../datasources/InsertIntoDataSource.scala | 8 +-
 .../datasources/InsertIntoHadoopFsRelation.scala | 16 +--
 .../execution/datasources/WriterContainer.scala | 7 +-
 .../execution/datasources/csv/DefaultSource.scala | 30 +++---
 .../spark/sql/execution/datasources/ddl.scala | 32 +++---
 .../datasources/fileSourceInterfaces.scala | 30 +++---
 .../execution/datasources/jdbc/DefaultSource.scala | 2 +-
 .../execution/datasources/jdbc/JDBCRelation.scala | 8 +-
 .../execution/datasources/json/JSONRelation.scala | 26 ++---
 .../datasources/parquet/ParquetRelation.scala | 61 +++++------
 .../spark/sql/execution/datasources/rules.scala | 6 +-
 .../execution/datasources/text/DefaultSource.scala | 12 +--
 .../spark/sql/execution/stat/FrequentItems.scala | 2 +-
 .../spark/sql/execution/stat/StatFunctions.scala | 2 +-
 .../sql/execution/streaming/FileStreamSink.scala | 10 +-
 .../execution/streaming/FileStreamSinkLog.scala | 12 +--
 .../sql/execution/streaming/FileStreamSource.scala | 10 +-
 .../sql/execution/streaming/HDFSMetadataLog.scala | 8 +-
 .../execution/streaming/IncrementalExecution.scala | 13 +--
 .../sql/execution/streaming/StreamExecution.scala | 17 ++--
 .../execution/streaming/StreamFileCatalog.scala | 8 +-
 .../spark/sql/execution/streaming/memory.scala | 4 +-
 .../org/apache/spark/sql/execution/subquery.scala | 6 +-
 .../scala/org/apache/spark/sql/functions.scala | 2 +-
 .../apache/spark/sql/internal/SessionState.scala | 34 ++++---
 .../spark/sql/DataFrameNaFunctionsSuite.scala | 2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala | 7 +-
 .../scala/org/apache/spark/sql/QueryTest.scala | 2 +-
 .../scala/org/apache/spark/sql/StreamTest.scala | 1 -
 .../apache/spark/sql/execution/SparkPlanTest.scala | 2 +-
 .../datasources/FileSourceStrategySuite.scala | 10 +-
 .../sql/execution/datasources/json/JsonSuite.scala | 4 +-
 .../streaming/FileStreamSinkLogSuite.scala | 2 +-
 .../execution/streaming/HDFSMetadataLogSuite.scala | 15 +--
 .../apache/spark/sql/sources/DDLTestSuite.scala | 14 ++-
 .../spark/sql/sources/FilteredScanSuite.scala | 8 +-
 .../apache/spark/sql/sources/PrunedScanSuite.scala | 8 +-
 .../sql/sources/ResolvedDataSourceSuite.scala | 2 +-
 .../apache/spark/sql/sources/TableScanSuite.scala | 20 ++--
 .../org/apache/spark/sql/test/SQLTestUtils.scala | 2 +-
 .../org/apache/spark/sql/test/TestSQLContext.scala | 4 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala | 47 +++++----
 .../apache/spark/sql/hive/HiveSessionCatalog.scala | 6 +-
 .../apache/spark/sql/hive/HiveSessionState.scala | 19 ++--
 .../org/apache/spark/sql/hive/HiveStrategies.scala | 4 +-
 .../apache/spark/sql/hive/MetastoreRelation.scala | 12 +--
 .../org/apache/spark/sql/hive/TableReader.scala | 18 ++--
 .../sql/hive/execution/CreateTableAsSelect.scala | 12 +--
 .../sql/hive/execution/HiveTableScanExec.scala | 8 +-
 .../apache/spark/sql/hive/orc/OrcRelation.scala | 29 +++---
 .../org/apache/spark/sql/hive/test/TestHive.scala | 12 ++-
 .../apache/spark/sql/catalyst/SQLBuilderTest.scala | 2 +-
 .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 6 +-
 .../sql/hive/execution/AggregationQuerySuite.scala | 2 +-
 .../sql/sources/CommitFailureTestSource.scala | 4 +-
 .../spark/sql/sources/SimpleTextRelation.scala | 12 +--
 86 files changed, 719 insertions(+), 664 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
index 953169b636..4d5afe2eb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
@@ -35,10 +35,10 @@ trait ContinuousQuery {
   def name: String

   /**
-   * Returns the SQLContext associated with `this` query
+   * Returns the [[SparkSession]] associated with `this`.
* @since 2.0.0 */ - def sqlContext: SQLContext + def sparkSession: SparkSession /** * Whether the query is currently active or not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 39d04ed8c2..9e2e2d0bc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -29,16 +29,16 @@ import org.apache.spark.sql.util.ContinuousQueryListener /** * :: Experimental :: * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active - * on a [[SQLContext]]. + * on a [[SparkSession]]. * * @since 2.0.0 */ @Experimental -class ContinuousQueryManager(sqlContext: SQLContext) { +class ContinuousQueryManager(sparkSession: SparkSession) { private[sql] val stateStoreCoordinator = - StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env) - private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) + StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) + private val listenerBus = new ContinuousQueryListenerBus(sparkSession.sparkContext.listenerBus) private val activeQueries = new mutable.HashMap[String, ContinuousQuery] private val activeQueriesLock = new Object private val awaitTerminationLock = new Object @@ -184,7 +184,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) { val analyzedPlan = df.queryExecution.analyzed df.queryExecution.assertAnalyzed() - if (sqlContext.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { + if (sparkSession.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) } @@ -201,7 +201,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) { StreamingExecutionRelation(source, output) } val query = new StreamExecution( - sqlContext, + sparkSession, name, checkpointLocation, logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f0e16eefc7..ad00966a91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -155,7 +155,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def fill(value: Double, cols: Seq[String]): DataFrame = { - val columnEquals = df.sqlContext.sessionState.analyzer.resolver + val columnEquals = df.sparkSession.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => // Only fill if the column is part of the cols list. if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { @@ -182,7 +182,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def fill(value: String, cols: Seq[String]): DataFrame = { - val columnEquals = df.sqlContext.sessionState.analyzer.resolver + val columnEquals = df.sparkSession.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => // Only fill if the column is part of the cols list. 
if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { @@ -355,7 +355,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case _: String => StringType } - val columnEquals = df.sqlContext.sessionState.analyzer.resolver + val columnEquals = df.sparkSession.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => val shouldReplace = cols.exists(colName => columnEquals(colName, f.name)) if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) { @@ -384,7 +384,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } } - val columnEquals = df.sqlContext.sessionState.analyzer.resolver + val columnEquals = df.sparkSession.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => v match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 15f2344df6..b49cda3f15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -36,12 +36,12 @@ import org.apache.spark.sql.types.StructType /** * :: Experimental :: * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, - * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this. + * key-value stores, etc) or data streams. Use [[SparkSession.read]] to access this. * * @since 1.4.0 */ @Experimental -class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { +class DataFrameReader protected[sql](sparkSession: SparkSession) extends Logging { /** * Specifies the input data source format. 
@@ -125,11 +125,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { def load(): DataFrame = { val dataSource = DataSource( - sqlContext, + sparkSession, userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) - Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())) + Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())) } /** @@ -151,11 +151,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { @scala.annotation.varargs def load(paths: String*): DataFrame = { if (paths.isEmpty) { - sqlContext.emptyDataFrame + sparkSession.emptyDataFrame } else { - sqlContext.baseRelationToDataFrame( + sparkSession.baseRelationToDataFrame( DataSource.apply( - sqlContext, + sparkSession, paths = paths, userSpecifiedSchema = userSpecifiedSchema, className = source, @@ -172,11 +172,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { def stream(): DataFrame = { val dataSource = DataSource( - sqlContext, + sparkSession, userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) - Dataset.ofRows(sqlContext, StreamingRelation(dataSource)) + Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) } /** @@ -271,8 +271,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { } // connectionProperties should override settings in extraOptions props.putAll(connectionProperties) - val relation = JDBCRelation(url, table, parts, props)(sqlContext) - sqlContext.baseRelationToDataFrame(relation) + val relation = JDBCRelation(url, table, parts, props)(sparkSession) + sparkSession.baseRelationToDataFrame(relation) } /** @@ -368,7 +368,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { val parsedOptions: JSONOptions = new JSONOptions(extraOptions.toMap) val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord - .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) + .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) val schema = userSpecifiedSchema.getOrElse { InferSchema.infer( jsonRDD, @@ -377,14 +377,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { } Dataset.ofRows( - sqlContext, + sparkSession, LogicalRDD( schema.toAttributes, JacksonParser.parse( jsonRDD, schema, columnNameOfCorruptRecord, - parsedOptions))(sqlContext)) + parsedOptions))(sparkSession)) } /** @@ -424,9 +424,9 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - Dataset.ofRows(sqlContext, - sqlContext.sessionState.catalog.lookupRelation( - sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))) + Dataset.ofRows(sparkSession, + sparkSession.sessionState.catalog.lookupRelation( + sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName))) } /** @@ -447,14 +447,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { */ @scala.annotation.varargs def text(paths: String*): Dataset[String] = { - format("text").load(paths : _*).as[String](sqlContext.implicits.newStringEncoder) + format("text").load(paths : _*).as[String](sparkSession.implicits.newStringEncoder) } /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// - private var source: String = 
sqlContext.conf.defaultDataSourceName + private var source: String = sparkSession.sessionState.conf.defaultDataSourceName private var userSpecifiedSchema: Option[StructType] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 99d92b9257..c0811f6a4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -237,7 +237,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { def save(): Unit = { assertNotBucketed() val dataSource = DataSource( - df.sqlContext, + df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec, @@ -284,7 +284,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { new Path(userSpecified).toUri.toString }.orElse { val checkpointConfig: Option[String] = - df.sqlContext.conf.getConf( + df.sparkSession.getConf( SQLConf.CHECKPOINT_LOCATION, None) @@ -297,7 +297,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // If offsets have already been created, we trying to resume a query. val checkpointPath = new Path(checkpointLocation, "offsets") - val fs = checkpointPath.getFileSystem(df.sqlContext.sessionState.hadoopConf) + val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.hadoopConf) if (fs.exists(checkpointPath)) { throw new AnalysisException( s"Unable to resume query written to memory sink. Delete $checkpointPath to start over.") @@ -306,9 +306,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { } val sink = new MemorySink(df.schema) - val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink)) + val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) resultDf.registerTempTable(queryName) - val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery( + val continuousQuery = df.sparkSession.sessionState.continuousQueryManager.startQuery( queryName, checkpointLocation, df, @@ -318,16 +318,16 @@ final class DataFrameWriter private[sql](df: DataFrame) { } else { val dataSource = DataSource( - df.sqlContext, + df.sparkSession, className = source, options = extraOptions.toMap, partitionColumns = normalizedParCols.getOrElse(Nil)) val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { - new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString + new Path(df.sparkSession.sessionState.conf.checkpointLocation, queryName).toUri.toString }) - df.sqlContext.sessionState.continuousQueryManager.startQuery( + df.sparkSession.sessionState.continuousQueryManager.startQuery( queryName, checkpointLocation, df, @@ -345,7 +345,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) + insertInto(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -363,7 +363,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { Project(inputDataCols ++ inputPartCols, df.logicalPlan) }.getOrElse(df.logicalPlan) - df.sqlContext.executePlan( + df.sparkSession.executePlan( InsertIntoTable( UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), @@ -413,7 +413,7 @@ final 
class DataFrameWriter private[sql](df: DataFrame) { */ private def normalize(columnName: String, columnType: String): String = { val validColumnNames = df.logicalPlan.output.map(_.name) - validColumnNames.find(df.sqlContext.sessionState.analyzer.resolver(_, columnName)) + validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + s"existing columns (${validColumnNames.mkString(", ")})")) } @@ -444,11 +444,11 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { - val tableExists = df.sqlContext.sessionState.catalog.tableExists(tableIdent) + val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) (tableExists, mode) match { case (true, SaveMode.Ignore) => @@ -468,7 +468,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { mode, extraOptions.toMap, df.logicalPlan) - df.sqlContext.executePlan(cmd).toRdd + df.sparkSession.executePlan(cmd).toRdd } } @@ -620,7 +620,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// - private var source: String = df.sqlContext.conf.defaultDataSourceName + private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName private var mode: SaveMode = SaveMode.ErrorIfExists diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c708cbf29..b3064fd531 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,20 +46,19 @@ import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils private[sql] object Dataset { - def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = { - new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]]) + def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { + new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) } - def ofRows(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { - val qe = sqlContext.executePlan(logicalPlan) + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + val qe = sparkSession.executePlan(logicalPlan) qe.assertAnalyzed() - new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema)) + new Dataset[Row](sparkSession, logicalPlan, RowEncoder(qe.analyzed.schema)) } } @@ -90,8 +89,8 @@ private[sql] object Dataset { * There are typically two ways to create a Dataset. The most common way is by pointing Spark * to some files on storage systems, using the `read` function available on a `SparkSession`. 
* {{{ - * val people = session.read.parquet("...").as[Person] // Scala - * Dataset people = session.read().parquet("...").as(Encoders.bean(Person.class) // Java + * val people = spark.read.parquet("...").as[Person] // Scala + * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class) // Java * }}} * * Datasets can also be created through transformations available on existing Datasets. For example, @@ -121,8 +120,8 @@ private[sql] object Dataset { * A more concrete example in Scala: * {{{ * // To create Dataset[Row] using SQLContext - * val people = session.read.parquet("...") - * val department = session.read.parquet("...") + * val people = spark.read.parquet("...") + * val department = spark.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -133,8 +132,8 @@ private[sql] object Dataset { * and in Java: * {{{ * // To create Dataset using SQLContext - * Dataset people = session.read().parquet("..."); - * Dataset department = session.read().parquet("..."); + * Dataset people = spark.read().parquet("..."); + * Dataset department = spark.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -152,8 +151,8 @@ private[sql] object Dataset { * * @since 1.6.0 */ -class Dataset[T] private[sql]( - @transient val sqlContext: SQLContext, +class Dataset[T] protected[sql]( + @transient val sparkSession: SparkSession, @DeveloperApi @transient val queryExecution: QueryExecution, encoder: Encoder[T]) extends Serializable { @@ -163,8 +162,12 @@ class Dataset[T] private[sql]( // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. + def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { + this(sparkSession, sparkSession.executePlan(logicalPlan), encoder) + } + def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { - this(sqlContext, sqlContext.executePlan(logicalPlan), encoder) + this(sqlContext.sparkSession, logicalPlan, encoder) } @transient protected[sql] val logicalPlan: LogicalPlan = { @@ -179,9 +182,9 @@ class Dataset[T] private[sql]( // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. 
case p if hasSideEffects(p) => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession) case Union(children) if children.forall(hasSideEffects) => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession) case _ => queryExecution.analyzed } @@ -207,8 +210,10 @@ class Dataset[T] private[sql]( private implicit def classTag = unresolvedTEncoder.clsTag + def sqlContext: SQLContext = sparkSession.wrapped + protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolveQuoted(colName, sqlContext.sessionState.analyzer.resolver) + queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver) .getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") @@ -217,7 +222,7 @@ class Dataset[T] private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolveQuoted(n.name, sqlContext.sessionState.analyzer.resolver).get + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get } } @@ -333,7 +338,7 @@ class Dataset[T] private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema)) + def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) /** * :: Experimental :: @@ -353,7 +358,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -407,7 +412,7 @@ class Dataset[T] private[sql]( */ def explain(extended: Boolean): Unit = { val explain = ExplainCommand(queryExecution.logical, extended = extended) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + sparkSession.executePlan(explain).executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println @@ -631,7 +636,7 @@ class Dataset[T] private[sql]( def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. - val joined = sqlContext.executePlan( + val joined = sparkSession.executePlan( Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] @@ -695,7 +700,7 @@ class Dataset[T] private[sql]( .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. 
- if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { + if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) { return withPlan(plan) } @@ -747,7 +752,7 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan - val joined = sqlContext.executePlan(Join(left, right, joinType = + val joined = sparkSession.executePlan(Join(left, right, joinType = JoinType(joinType), Some(condition.expr))) val leftOutput = joined.analyzed.output.take(left.output.length) val rightOutput = joined.analyzed.output.takeRight(right.output.length) @@ -968,7 +973,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(sqlContext.sessionState.sqlParser.parseExpression(expr)) + Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) }: _*) } @@ -987,7 +992,7 @@ class Dataset[T] private[sql]( @Experimental def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { new Dataset[U1]( - sqlContext, + sparkSession, Project( c1.withInputType( unresolvedTEncoder.deserializer, @@ -1005,9 +1010,8 @@ class Dataset[T] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) - val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) - - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) + new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } /** @@ -1091,7 +1095,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) + filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr))) } /** @@ -1117,7 +1121,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def where(conditionExpr: String): Dataset[T] = { - filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) + filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr))) } /** @@ -1254,7 +1258,7 @@ class Dataset[T] private[sql]( def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) - val executed = sqlContext.executePlan(withGroupingKey) + val executed = sparkSession.executePlan(withGroupingKey) new KeyValueGroupedDataset( encoderFor[K], @@ -1507,7 +1511,7 @@ class Dataset[T] private[sql]( val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) }.toArray } @@ -1630,7 +1634,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def withColumn(colName: String, col: Column): DataFrame = { - val resolver = sqlContext.sessionState.analyzer.resolver + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output val shouldReplace = output.exists(f => resolver(f.name, colName)) if (shouldReplace) { @@ -1651,7 +1655,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by adding a column with metadata. 
*/ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { - val resolver = sqlContext.sessionState.analyzer.resolver + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output val shouldReplace = output.exists(f => resolver(f.name, colName)) if (shouldReplace) { @@ -1676,7 +1680,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def withColumnRenamed(existingName: String, newName: String): DataFrame = { - val resolver = sqlContext.sessionState.analyzer.resolver + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output val shouldRename = output.exists(f => resolver(f.name, existingName)) if (shouldRename) { @@ -1713,7 +1717,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def drop(colNames: String*): DataFrame = { - val resolver = sqlContext.sessionState.analyzer.resolver + val resolver = sparkSession.sessionState.analyzer.resolver val remainingCols = schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) if (remainingCols.size == this.schema.size) { @@ -1736,7 +1740,7 @@ class Dataset[T] private[sql]( val expression = col match { case Column(u: UnresolvedAttribute) => queryExecution.analyzed.resolveQuoted( - u.name, sqlContext.sessionState.analyzer.resolver).getOrElse(u) + u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } val attrs = this.logicalPlan.output @@ -1957,7 +1961,7 @@ class Dataset[T] private[sql]( @Experimental def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( - sqlContext, + sparkSession, MapPartitions[T, U](func, logicalPlan), implicitly[Encoder[U]]) } @@ -2203,7 +2207,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def persist(): this.type = { - sqlContext.cacheManager.cacheQuery(this) + sparkSession.cacheManager.cacheQuery(this) this } @@ -2225,7 +2229,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheManager.cacheQuery(this, None, newLevel) + sparkSession.cacheManager.cacheQuery(this, None, newLevel) this } @@ -2238,7 +2242,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sqlContext.cacheManager.tryUncacheQuery(this, blocking) + sparkSession.cacheManager.tryUncacheQuery(this, blocking) this } @@ -2259,7 +2263,7 @@ class Dataset[T] private[sql]( lazy val rdd: RDD[T] = { val objectType = unresolvedTEncoder.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) - sqlContext.executePlan(deserialized).toRdd.mapPartitions { rows => + sparkSession.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } } @@ -2286,7 +2290,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def registerTempTable(tableName: String): Unit = { - sqlContext.registerDataFrameAsTable(toDF(), tableName) + sparkSession.registerDataFrameAsTable(toDF(), tableName) } /** @@ -2327,8 +2331,8 @@ class Dataset[T] private[sql]( } } } - import sqlContext.implicits.newStringEncoder - sqlContext.createDataset(rdd) + import sparkSession.implicits.newStringEncoder + sparkSession.createDataset(rdd) } /** @@ -2383,7 +2387,7 @@ class Dataset[T] private[sql]( * an execution. 
*/ private[sql] def withNewExecutionId[U](body: => U): U = { - SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body) + SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } /** @@ -2398,11 +2402,11 @@ class Dataset[T] private[sql]( val start = System.nanoTime() val result = action(df) val end = System.nanoTime() - sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start) + sparkSession.listenerManager.onSuccess(name, df.queryExecution, end - start) result } catch { case e: Exception => - sqlContext.listenerManager.onFailure(name, df.queryExecution, e) + sparkSession.listenerManager.onFailure(name, df.queryExecution, e) throw e } } @@ -2415,11 +2419,11 @@ class Dataset[T] private[sql]( val start = System.nanoTime() val result = action(ds) val end = System.nanoTime() - sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start) + sparkSession.listenerManager.onSuccess(name, ds.queryExecution, end - start) result } catch { case e: Exception => - sqlContext.listenerManager.onFailure(name, ds.queryExecution, e) + sparkSession.listenerManager.onFailure(name, ds.queryExecution, e) throw e } } @@ -2440,11 +2444,11 @@ class Dataset[T] private[sql]( /** A convenient function to wrap a logical plan and produce a DataFrame. */ @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { - Dataset.ofRows(sqlContext, logicalPlan) + Dataset.ofRows(sparkSession, logicalPlan) } /** A convenient function to wrap a logical plan and produce a Dataset. */ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { - Dataset(sqlContext, logicalPlan) + Dataset(sparkSession, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 05e13e66d1..3a5ea19b8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -55,7 +55,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed - private def sqlContext = queryExecution.sqlContext + private def sparkSession = queryExecution.sparkSession /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -79,7 +79,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ def keys: Dataset[K] = { Dataset[K]( - sqlContext, + sparkSession, Distinct( Project(groupingAttributes, logicalPlan))) } @@ -104,7 +104,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { Dataset[U]( - sqlContext, + sparkSession, MapGroups( f, groupingAttributes, @@ -217,10 +217,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( Alias(CreateStruct(groupingAttributes), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) - val execution = new QueryExecution(sqlContext, aggregate) + val execution = new QueryExecution(sparkSession, aggregate) new Dataset( - sqlContext, + sparkSession, execution, ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) } @@ -289,7 +289,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit val uEncoder = other.unresolvedVEncoder Dataset[R]( - sqlContext, + sparkSession, 
CoGroup( f, this.groupingAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0ffb136c24..7ee9732fa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -42,7 +42,7 @@ class RelationalGroupedDataset protected[sql]( groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { - val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { groupingExprs ++ aggExprs } else { aggExprs @@ -53,17 +53,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => Dataset.ofRows( - df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } @@ -302,7 +302,7 @@ class RelationalGroupedDataset protected[sql]( */ def pivot(pivotColumn: String): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + val maxValues = df.sparkSession.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dde139608a..47c043a00d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -58,7 +58,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @since 1.0.0 */ class SQLContext private[sql]( - @transient private val sparkSession: SparkSession, + val sparkSession: SparkSession, val isRootContext: Boolean) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 00256bd0be..a0f0bd3f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -52,7 +52,8 @@ import org.apache.spark.util.Utils */ class SparkSession private( @transient val sparkContext: SparkContext, - @transient private val existingSharedState: Option[SharedState]) { self => + @transient private val existingSharedState: Option[SharedState]) + extends Serializable { self => def this(sc: 
SparkContext) { this(sc, None) @@ -81,9 +82,9 @@ class SparkSession private( */ @transient protected[sql] lazy val sessionState: SessionState = { - SparkSession.reflect[SessionState, SQLContext]( + SparkSession.reflect[SessionState, SparkSession]( SparkSession.sessionStateClassName(sparkContext.conf), - new SQLContext(self, isRootContext = false)) + self) } /** @@ -358,7 +359,7 @@ class SparkSession private( val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) - Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRDD)(wrapped)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self)) } /** @@ -373,7 +374,7 @@ class SparkSession private( SQLContext.setActive(wrapped) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - Dataset.ofRows(wrapped, LocalRelation.fromProduct(attributeSeq, data)) + Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) } /** @@ -438,7 +439,7 @@ class SparkSession private( */ @DeveloperApi def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { - Dataset.ofRows(wrapped, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } /** @@ -458,7 +459,7 @@ class SparkSession private( val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) } - Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRdd)(wrapped)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) } /** @@ -486,7 +487,7 @@ class SparkSession private( val attrSeq = getSchema(beanClass) val beanInfo = Introspector.getBeanInfo(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) - Dataset.ofRows(wrapped, LocalRelation(attrSeq, rows.toSeq)) + Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) } /** @@ -496,7 +497,7 @@ class SparkSession private( * @since 2.0.0 */ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { - Dataset.ofRows(wrapped, LogicalRelation(baseRelation)) + Dataset.ofRows(self, LogicalRelation(baseRelation)) } def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { @@ -504,15 +505,15 @@ class SparkSession private( val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) - Dataset[T](wrapped, plan) + Dataset[T](self, plan) } def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d)) - val plan = LogicalRDD(attributes, encoded)(wrapped) - Dataset[T](wrapped, plan) + val plan = LogicalRDD(attributes, encoded)(self) + Dataset[T](self, plan) } def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { @@ -567,7 +568,7 @@ class SparkSession private( */ @Experimental def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { - new Dataset(wrapped, Range(start, end, step, numPartitions), Encoders.LONG) + new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG) } /** @@ -579,8 +580,8 @@ class SparkSession private( schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing 
schema on any field data type. - val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped) - Dataset.ofRows(wrapped, logicalPlan) + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) + Dataset.ofRows(self, logicalPlan) } /** @@ -599,8 +600,8 @@ class SparkSession private( } else { rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } - val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped) - Dataset.ofRows(wrapped, logicalPlan) + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) + Dataset.ofRows(self, logicalPlan) } @@ -749,7 +750,7 @@ class SparkSession private( } private def table(tableIdent: TableIdentifier): DataFrame = { - Dataset.ofRows(wrapped, sessionState.catalog.lookupRelation(tableIdent)) + Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent)) } /** @@ -761,7 +762,7 @@ class SparkSession private( * @since 2.0.0 */ def tables(): DataFrame = { - Dataset.ofRows(wrapped, ShowTablesCommand(None, None)) + Dataset.ofRows(self, ShowTablesCommand(None, None)) } /** @@ -773,7 +774,7 @@ class SparkSession private( * @since 2.0.0 */ def tables(databaseName: String): DataFrame = { - Dataset.ofRows(wrapped, ShowTablesCommand(Some(databaseName), None)) + Dataset.ofRows(self, ShowTablesCommand(Some(databaseName), None)) } /** @@ -820,7 +821,7 @@ class SparkSession private( * @since 2.0.0 */ def sql(sqlText: String): DataFrame = { - Dataset.ofRows(wrapped, parseSql(sqlText)) + Dataset.ofRows(self, parseSql(sqlText)) } /** @@ -835,7 +836,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - def read: DataFrameReader = new DataFrameReader(wrapped) + def read: DataFrameReader = new DataFrameReader(self) // scalastyle:off @@ -906,7 +907,7 @@ class SparkSession private( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) - Dataset.ofRows(wrapped, LogicalRDD(schema.toAttributes, rowRdd)(wrapped)) + Dataset.ofRows(self, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 124ec09efd..f601138a9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -87,15 +87,15 @@ private[sql] class CacheManager extends Logging { if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { - val sqlContext = query.sqlContext + val sparkSession = query.sparkSession cachedData += CachedData( planToCache, InMemoryRelation( - sqlContext.conf.useCompression, - sqlContext.conf.columnBatchSize, + sparkSession.sessionState.conf.useCompression, + sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(planToCache).executedPlan, + sparkSession.executePlan(planToCache).executedPlan, tableName)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 7afdf75f38..520ceaaaea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import 
org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ @@ -75,15 +75,15 @@ object RDDConversions { /** Logical plan node for scanning data from an RDD. */ private[sql] case class LogicalRDD( output: Seq[Attribute], - rdd: RDD[InternalRow])(sqlContext: SQLContext) + rdd: RDD[InternalRow])(session: SparkSession) extends LogicalPlan with MultiInstanceRelation { override def children: Seq[LogicalPlan] = Nil - override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil + override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil override def newInstance(): LogicalRDD.this.type = - LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] + LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type] override def sameResult(plan: LogicalPlan): Boolean = plan match { case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id @@ -95,7 +95,7 @@ private[sql] case class LogicalRDD( @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) } @@ -329,7 +329,8 @@ private[sql] object DataSourceScanExec { val outputPartitioning = { val bucketSpec = relation match { // TODO: this should be closer to bucket planning. - case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec + case r: HadoopFsRelation + if r.sparkSession.sessionState.conf.bucketingEnabled => r.bucketSpec case _ => None } @@ -349,7 +350,7 @@ private[sql] object DataSourceScanExec { relation match { case r: HadoopFsRelation - if r.fileFormat.supportBatch(r.sqlContext, StructType.fromAttributes(output)) => + if r.fileFormat.supportBatch(r.sparkSession, StructType.fromAttributes(output)) => BatchedDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata) case _ => RowDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index bb83676b7d..d3d83b0218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} @@ -39,39 +39,41 @@ import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampT * While this is not a public class, we should avoid changing the function names for the sake of * changing them, because a lot of developers use the feature for debugging. 
*/ -class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { +class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // TODO: Move the planner an optimizer into here from SessionState. - protected def planner = sqlContext.sessionState.planner - - def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae + protected def planner = sparkSession.sessionState.planner + + def assertAnalyzed(): Unit = { + try sparkSession.sessionState.analyzer.checkAnalysis(analyzed) catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } } def assertSupported(): Unit = { - if (sqlContext.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { + if (sparkSession.sessionState.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { UnsupportedOperationChecker.checkForBatch(analyzed) } } lazy val analyzed: LogicalPlan = { - SQLContext.setActive(sqlContext) - sqlContext.sessionState.analyzer.execute(logical) + SQLContext.setActive(sparkSession.wrapped) + sparkSession.sessionState.analyzer.execute(logical) } lazy val withCachedData: LogicalPlan = { assertAnalyzed() assertSupported() - sqlContext.cacheManager.useCachedData(analyzed) + sparkSession.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData) lazy val sparkPlan: SparkPlan = { - SQLContext.setActive(sqlContext) + SQLContext.setActive(sparkSession.wrapped) planner.plan(ReturnAnswer(optimizedPlan)).next() } @@ -93,10 +95,10 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, - PlanSubqueries(sqlContext), - EnsureRequirements(sqlContext.conf), - CollapseCodegenStages(sqlContext.conf), - ReuseExchange(sqlContext.conf)) + PlanSubqueries(sparkSession), + EnsureRequirements(sparkSession.sessionState.conf), + CollapseCodegenStages(sparkSession.sessionState.conf), + ReuseExchange(sparkSession.sessionState.conf)) protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -110,7 +112,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { case ExecutedCommandExec(desc: DescribeTableCommand) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. 
- desc.run(sqlContext).map { + desc.run(sparkSession).map { case Row(name: String, dataType: String, comment) => Seq(name, dataType, Option(comment.asInstanceOf[String]).getOrElse("")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 0a11b16d0e..397d66b311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} import org.apache.spark.util.Utils @@ -38,21 +38,22 @@ private[sql] object SQLExecution { * we can connect them with an execution. */ def withNewExecutionId[T]( - sqlContext: SQLContext, queryExecution: QueryExecution)(body: => T): T = { - val sc = sqlContext.sparkContext + sparkSession: SparkSession, + queryExecution: QueryExecution)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) if (oldExecutionId == null) { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) } } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index e28e456662..861ff3cd15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala index e6c5351106..b6f7808398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala @@ -21,7 +21,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import 
org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} @@ -35,8 +35,8 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} */ case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val sessionState = sqlContext.sessionState + override def run(sparkSession: SparkSession): Seq[Row] = { + val sessionState = sparkSession.sessionState val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) @@ -77,7 +77,7 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { catalogTable.storage.locationUri.map { p => val path = new Path(p) try { - val fs = path.getFileSystem(sqlContext.sessionState.hadoopConf) + val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf) calculateTableSize(fs, path) } catch { case NonFatal(e) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala index 39e441f1c3..bf66ea46fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.types.StringType @@ -29,7 +29,7 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand { override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.runNativeSql(sql).map(Row(_)) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.runNativeSql(sql).map(Row(_)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 4daf9e916a..952a0d676f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import java.util.NoSuchElementException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -43,10 +43,10 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm schema.toAttributes } - private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match { + private val (_output, runFunc): (Seq[Attribute], SparkSession => Seq[Row]) = kv match { // Configures the deprecated "mapred.reduce.tasks" property. 
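// ---------------------------------------------------------------------------
// Editor's illustrative sketch, not part of the patch: the reworked
// SQLExecution.withNewExecutionId takes the SparkSession and reaches the
// SparkContext and listener bus through it. `spark` is a hypothetical session.
val df = spark.sql("SELECT 1")
SQLExecution.withNewExecutionId(spark, df.queryExecution) {
  df.queryExecution.toRdd.count()  // work done here is tied to one execution id in the UI
}
// ---------------------------------------------------------------------------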
case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") @@ -56,14 +56,14 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm "determining the number of reducers is not supported." throw new IllegalArgumentException(msg) } else { - sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) + sparkSession.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) } } (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + s"External sort will continue to be used.") @@ -72,7 +72,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " + s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " + @@ -82,7 +82,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.TUNGSTEN_ENABLED, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.TUNGSTEN_ENABLED} is deprecated and " + s"will be ignored. Tungsten will continue to be used.") @@ -91,7 +91,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.CODEGEN_ENABLED, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.CODEGEN_ENABLED} is deprecated and " + s"will be ignored. Codegen will continue to be used.") @@ -100,7 +100,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.UNSAFE_ENABLED, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.UNSAFE_ENABLED} is deprecated and " + s"will be ignored. Unsafe mode will continue to be used.") @@ -109,7 +109,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.SORTMERGE_JOIN} is deprecated and " + s"will be ignored. 
Sort merge join will continue to be used.") @@ -118,7 +118,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) case Some((SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED} is " + s"deprecated and will be ignored. Vectorized parquet reader will be used instead.") @@ -128,25 +128,25 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm // Configures a single property. case Some((key, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { - sqlContext.setConf(key, value) + val runFunc = (sparkSession: SparkSession) => { + sparkSession.setConf(key, value) Seq(Row(key, value)) } (keyValueOutput, runFunc) // (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.) - // Queries all key-value pairs that are set in the SQLConf of the sqlContext. + // Queries all key-value pairs that are set in the SQLConf of the sparkSession. case None => - val runFunc = (sqlContext: SQLContext) => { - sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq + val runFunc = (sparkSession: SparkSession) => { + sparkSession.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq } (keyValueOutput, runFunc) // Queries all properties along with their default values and docs that are defined in the - // SQLConf of the sqlContext. + // SQLConf of the sparkSession. case Some(("-v", None)) => - val runFunc = (sqlContext: SQLContext) => { - sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => + val runFunc = (sparkSession: SparkSession) => { + sparkSession.sessionState.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => Row(key, defaultValue, doc) } } @@ -158,19 +158,21 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm // Queries the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { logWarning( s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") - Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString)) + Seq(Row( + SQLConf.SHUFFLE_PARTITIONS.key, + sparkSession.sessionState.conf.numShufflePartitions.toString)) } (keyValueOutput, runFunc) // Queries a single property. 
case Some((key, None)) => - val runFunc = (sqlContext: SQLContext) => { + val runFunc = (sparkSession: SparkSession) => { val value = - try sqlContext.getConf(key) catch { + try sparkSession.getConf(key) catch { case _: NoSuchElementException => "" } Seq(Row(key, value)) @@ -180,6 +182,6 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm override val output: Seq[Attribute] = _output - override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) + override def run(sparkSession: SparkSession): Seq[Row] = runFunc(sparkSession) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 5be5d0c2b0..c283bd61d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Dataset, Row, SQLContext} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -28,15 +28,15 @@ case class CacheTableCommand( isLazy: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => - sqlContext.registerDataFrameAsTable(Dataset.ofRows(sqlContext, logicalPlan), tableName) + sparkSession.registerDataFrameAsTable(Dataset.ofRows(sparkSession, logicalPlan), tableName) } - sqlContext.cacheTable(tableName) + sparkSession.cacheTable(tableName) if (!isLazy) { // Performs eager caching - sqlContext.table(tableName).count() + sparkSession.table(tableName).count() } Seq.empty[Row] @@ -48,8 +48,8 @@ case class CacheTableCommand( case class UncacheTableCommand(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.table(tableName).unpersist(blocking = false) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.table(tableName).unpersist(blocking = false) Seq.empty[Row] } @@ -61,8 +61,8 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { */ case object ClearCacheCommand extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.clearCache() + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.clearCache() Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 0fd7fa92a3..7bb59b7803 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row, SQLContext} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical @@ -35,7 +35,7 @@ 
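// ---------------------------------------------------------------------------
// Editor's illustrative sketch, not part of the patch: the SET command above now
// goes through the session for both reads and writes of SQL configuration.
// `sparkSession` stands for any session; the key is the one the command uses.
sparkSession.setConf(SQLConf.SHUFFLE_PARTITIONS.key, "10")    // write a property
val n = sparkSession.sessionState.conf.numShufflePartitions   // typed read of the same property
val all = sparkSession.getAllConfs                            // everything explicitly set
// ---------------------------------------------------------------------------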
import org.apache.spark.sql.types._ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty - def run(sqlContext: SQLContext): Seq[Row] + def run(sparkSession: SparkSession): Seq[Row] } /** @@ -54,7 +54,7 @@ private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkP */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - cmd.run(sqlContext).map(converter(_).asInstanceOf[InternalRow]) + cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) } override def output: Seq[Attribute] = cmd.output @@ -97,8 +97,8 @@ case class ExplainCommand( extends RunnableCommand { // Run through the optimizer to generate the physical plan. - override def run(sqlContext: SQLContext): Seq[Row] = try { - val queryExecution = sqlContext.executePlan(logicalPlan) + override def run(sparkSession: SparkSession): Seq[Row] = try { + val queryExecution = sparkSession.executePlan(logicalPlan) val outputString = if (codegen) { codegenString(queryExecution.executedPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 0ef1d1d688..31900b4993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -55,7 +55,7 @@ case class CreateDataSourceTableCommand( managedIfNoPath: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // Since we are saving metadata to metastore, we need to check if metastore supports // the table name and database name we have for this query. MetaStoreUtils.validateName // is the method used by Hive to check if a table name or a database name is valid for @@ -72,7 +72,7 @@ case class CreateDataSourceTableCommand( } val tableName = tableIdent.unquotedString - val sessionState = sqlContext.sessionState + val sessionState = sparkSession.sessionState if (sessionState.catalog.tableExists(tableIdent)) { if (ignoreIfExists) { @@ -93,14 +93,14 @@ case class CreateDataSourceTableCommand( // Create the relation to validate the arguments before writing the metadata to the metastore. DataSource( - sqlContext = sqlContext, + sparkSession = sparkSession, userSpecifiedSchema = userSpecifiedSchema, className = provider, bucketSpec = None, options = optionsWithPath).resolveRelation() CreateDataSourceTableUtils.createDataSourceTable( - sqlContext = sqlContext, + sparkSession = sparkSession, tableIdent = tableIdent, userSpecifiedSchema = userSpecifiedSchema, partitionColumns = Array.empty[String], @@ -136,7 +136,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // Since we are saving metadata to metastore, we need to check if metastore supports // the table name and database name we have for this query. 
MetaStoreUtils.validateName // is the method used by Hive to check if a table name or a database name is valid for @@ -153,7 +153,7 @@ case class CreateDataSourceTableAsSelectCommand( } val tableName = tableIdent.unquotedString - val sessionState = sqlContext.sessionState + val sessionState = sparkSession.sessionState var createMetastoreTable = false var isExternal = true val optionsWithPath = @@ -165,7 +165,7 @@ case class CreateDataSourceTableAsSelectCommand( } var existingSchema = None: Option[StructType] - if (sqlContext.sessionState.catalog.tableExists(tableIdent)) { + if (sparkSession.sessionState.catalog.tableExists(tableIdent)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -180,7 +180,7 @@ case class CreateDataSourceTableAsSelectCommand( case SaveMode.Append => // Check if the specified data source match the data source of the existing table. val dataSource = DataSource( - sqlContext = sqlContext, + sparkSession = sparkSession, userSpecifiedSchema = Some(query.schema.asNullable), partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -197,7 +197,7 @@ case class CreateDataSourceTableAsSelectCommand( throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } case SaveMode.Overwrite => - sqlContext.sql(s"DROP TABLE IF EXISTS $tableName") + sparkSession.sql(s"DROP TABLE IF EXISTS $tableName") // Need to create the table again. createMetastoreTable = true } @@ -206,7 +206,7 @@ case class CreateDataSourceTableAsSelectCommand( createMetastoreTable = true } - val data = Dataset.ofRows(sqlContext, query) + val data = Dataset.ofRows(sparkSession, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. case Some(s) => data.selectExpr(s.fieldNames: _*) @@ -215,7 +215,7 @@ case class CreateDataSourceTableAsSelectCommand( // Create the relation based on the data of df. val dataSource = DataSource( - sqlContext, + sparkSession, className = provider, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -228,7 +228,7 @@ case class CreateDataSourceTableAsSelectCommand( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). CreateDataSourceTableUtils.createDataSourceTable( - sqlContext = sqlContext, + sparkSession = sparkSession, tableIdent = tableIdent, userSpecifiedSchema = Some(result.schema), partitionColumns = partitionColumns, @@ -260,7 +260,7 @@ object CreateDataSourceTableUtils extends Logging { } def createDataSourceTable( - sqlContext: SQLContext, + sparkSession: SparkSession, tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], @@ -275,7 +275,7 @@ object CreateDataSourceTableUtils extends Logging { // stored into a single metastore SerDe property. In this case, we split the JSON string and // store each part as a separate SerDe property. userSpecifiedSchema.foreach { schema => - val threshold = sqlContext.sessionState.conf.schemaStringLengthThreshold + val threshold = sparkSession.sessionState.conf.schemaStringLengthThreshold val schemaJsonString = schema.json // Split the JSON string. 
val parts = schemaJsonString.grouped(threshold).toSeq @@ -329,10 +329,10 @@ object CreateDataSourceTableUtils extends Logging { CatalogTableType.MANAGED_TABLE } - val maybeSerDe = HiveSerDe.sourceToSerDe(provider, sqlContext.sessionState.conf) + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, sparkSession.sessionState.conf) val dataSource = DataSource( - sqlContext, + sparkSession, userSpecifiedSchema = userSpecifiedSchema, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -432,7 +432,7 @@ object CreateDataSourceTableUtils extends Logging { // specific way. try { logInfo(message) - sqlContext.sessionState.catalog.createTable(table, ignoreIfExists = false) + sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false) } catch { case NonFatal(e) => val warningMessage = @@ -440,13 +440,13 @@ object CreateDataSourceTableUtils extends Logging { s"it into Hive metastore in Spark SQL specific format." logWarning(warningMessage, e) val table = newSparkSQLSpecificMetastoreTable() - sqlContext.sessionState.catalog.createTable(table, ignoreIfExists = false) + sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false) } case (None, message) => logWarning(message) val table = newSparkSQLSpecificMetastoreTable() - sqlContext.sessionState.catalog.createTable(table, ignoreIfExists = false) + sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index 33cc10d53a..cefe0f6e62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types.StringType @@ -38,10 +38,10 @@ case class ShowDatabasesCommand(databasePattern: Option[String]) extends Runnabl AttributeReference("result", StringType, nullable = false)() :: Nil } - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog val databases = - databasePattern.map(catalog.listDatabases(_)).getOrElse(catalog.listDatabases()) + databasePattern.map(catalog.listDatabases).getOrElse(catalog.listDatabases()) databases.map { d => Row(d) } } } @@ -55,8 +55,8 @@ case class ShowDatabasesCommand(databasePattern: Option[String]) extends Runnabl */ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.catalog.setCurrentDatabase(databaseName) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.catalog.setCurrentDatabase(databaseName) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 85f0066f3b..f5aa8fb6fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,7 @@ package 
org.apache.spark.sql.execution.command import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog} @@ -38,8 +38,8 @@ import org.apache.spark.sql.types._ */ abstract class NativeDDLCommand(val sql: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.runNativeSql(sql) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.runNativeSql(sql) } override val output: Seq[Attribute] = { @@ -66,8 +66,8 @@ case class CreateDatabase( props: Map[String, String]) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog catalog.createDatabase( CatalogDatabase( databaseName, @@ -104,8 +104,8 @@ case class DropDatabase( cascade: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) Seq.empty[Row] } @@ -126,8 +126,8 @@ case class AlterDatabaseProperties( props: Map[String, String]) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog val db: CatalogDatabase = catalog.getDatabaseMetadata(databaseName) catalog.alterDatabase(db.copy(properties = db.properties ++ props)) @@ -152,9 +152,9 @@ case class DescribeDatabase( extended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { val dbMetadata: CatalogDatabase = - sqlContext.sessionState.catalog.getDatabaseMetadata(databaseName) + sparkSession.sessionState.catalog.getDatabaseMetadata(databaseName) val result = Row("Database Name", dbMetadata.name) :: Row("Description", dbMetadata.description) :: @@ -193,8 +193,8 @@ case class DropTable( ifExists: Boolean, isView: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog if (!catalog.tableExists(tableName)) { if (!ifExists) { val objectName = if (isView) "View" else "Table" @@ -213,7 +213,7 @@ case class DropTable( case _ => }) try { - sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) + sparkSession.cacheManager.tryUncacheQuery(sparkSession.table(tableName.quotedString)) } catch { case NonFatal(e) => log.warn(s"${e.getMessage}", e) } @@ -239,8 +239,8 @@ case class AlterTableSetProperties( isView: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = 
sparkSession.sessionState.catalog DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) val newProperties = table.properties ++ properties @@ -271,8 +271,8 @@ case class AlterTableUnsetProperties( isView: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { @@ -315,8 +315,8 @@ case class AlterTableSerDeProperties( require(serdeClassName.isDefined || serdeProperties.isDefined, "alter table attempted to set neither serde class name nor serde properties") - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) // Do not support setting serde for datasource tables if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { @@ -350,8 +350,8 @@ case class AlterTableAddPartition( ifNotExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -381,8 +381,8 @@ case class AlterTableRenamePartition( newPartition: TablePartitionSpec) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.catalog.renamePartitions( + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.catalog.renamePartitions( tableName, Seq(oldPartition), Seq(newPartition)) Seq.empty[Row] } @@ -409,8 +409,8 @@ case class AlterTableDropPartition( ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -446,8 +446,8 @@ case class AlterTableSetLocation( location: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) partitionSpec match { case Some(spec) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 89ccacdc73..5aa779ddeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import 
org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogFunction import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionInfo} @@ -47,8 +47,8 @@ case class CreateFunction( isTemp: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException( @@ -99,7 +99,7 @@ case class DescribeFunction( } } - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. functionName.toLowerCase match { case "<>" => @@ -116,7 +116,7 @@ case class DescribeFunction( Row(s"Function: case") :: Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + s"When a = b, returns c; when a = d, return e; else return f") :: Nil - case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { + case _ => sparkSession.sessionState.functionRegistry.lookupFunction(functionName) match { case Some(info) => val result = Row(s"Function: ${info.getName}") :: @@ -149,8 +149,8 @@ case class DropFunction( isTemp: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException( @@ -187,12 +187,12 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru schema.toAttributes } - override def run(sqlContext: SQLContext): Seq[Row] = { - val dbName = db.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) + override def run(sparkSession: SparkSession): Seq[Row] = { + val dbName = db.getOrElse(sparkSession.sessionState.catalog.getCurrentDatabase) // If pattern is not specified, we use '*', which is used to // match any sequence of characters (including no characters). 
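// ---------------------------------------------------------------------------
// Editor's illustrative sketch, not part of the patch: function metadata is also
// resolved through the session, e.g. listing functions in a hypothetical
// database "default" with the same calls used in this file:
val fns = sparkSession.sessionState.catalog
  .listFunctions("default", "*")
  .map(_.unquotedString)
// ---------------------------------------------------------------------------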
val functionNames = - sqlContext.sessionState.catalog + sparkSession.sessionState.catalog .listFunctions(dbName, pattern.getOrElse("*")) .map(_.unquotedString) // The session catalog caches some persistent functions in the FunctionRegistry diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index fc7ecb11ec..29bcb30592 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -31,8 +31,8 @@ case class AddJar(path: String) extends RunnableCommand { schema.toAttributes } - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.addJar(path) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.addJar(path) Seq(Row(0)) } } @@ -41,8 +41,8 @@ case class AddJar(path: String) extends RunnableCommand { * Adds a file to the current session so it can be used. */ case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sparkContext.addFile(path) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sparkContext.addFile(path) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 5cac9d879f..700a704941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -22,7 +22,7 @@ import java.net.URI import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogTableType, ExternalCatalog} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -60,8 +60,8 @@ case class CreateTableLike( sourceTable: TableIdentifier, ifNotExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog if (!catalog.tableExists(sourceTable)) { throw new AnalysisException( s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") @@ -109,8 +109,8 @@ case class CreateTableLike( */ case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.catalog.createTable(table, ifNotExists) + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.catalog.createTable(table, ifNotExists) Seq.empty[Row] } @@ -132,8 +132,8 @@ case class AlterTableRename( isView: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog 
+ override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog DDLUtils.verifyAlterTableType(catalog, oldName, isView) catalog.invalidateTable(oldName) catalog.renameTable(oldName, newName) @@ -158,8 +158,8 @@ case class LoadData( isOverwrite: Boolean, partition: Option[ExternalCatalog.TablePartitionSpec]) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog if (!catalog.tableExists(table)) { throw new AnalysisException( s"Table in LOAD DATA does not exist: '$table'") @@ -210,7 +210,7 @@ case class LoadData( // Follow Hive's behavior: // If no schema or authority is provided with non-local inpath, // we will use hadoop configuration "fs.default.name". - val defaultFSConf = sqlContext.sessionState.hadoopConf.get("fs.default.name") + val defaultFSConf = sparkSession.sessionState.hadoopConf.get("fs.default.name") val defaultFS = if (defaultFSConf == null) { new URI("") } else { @@ -285,9 +285,9 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean) new MetadataBuilder().putString("comment", "comment of the column").build())() ) - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { val result = new ArrayBuffer[Row] - sqlContext.sessionState.catalog.lookupRelation(table) match { + sparkSession.sessionState.catalog.lookupRelation(table) match { case catalogRelation: CatalogRelation => catalogRelation.catalogTable.schema.foreach { column => result += Row(column.name, column.dataType, column.comment.orNull) @@ -333,10 +333,10 @@ case class ShowTablesCommand( AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil } - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // Since we need to return a Seq of rows, we will call getTables directly - // instead of calling tables in sqlContext. - val catalog = sqlContext.sessionState.catalog + // instead of calling tables in sparkSession. 
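// ---------------------------------------------------------------------------
// Editor's illustrative sketch, not part of the patch: the commands above all
// share one shape after this refactor, run(sparkSession) plus the session's
// catalog. UseDatabaseSketch is a hypothetical command, modeled on the
// SetDatabaseCommand shown earlier in this patch, only to make the pattern explicit.
case class UseDatabaseSketch(db: String) extends RunnableCommand {
  override def run(sparkSession: SparkSession): Seq[Row] = {
    sparkSession.sessionState.catalog.setCurrentDatabase(db)
    Seq.empty[Row]
  }
}
// ---------------------------------------------------------------------------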
+ val catalog = sparkSession.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) val tables = tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) @@ -368,13 +368,13 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio } } - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog if (catalog.isTemporaryTable(table)) { Seq.empty[Row] } else { - val catalogTable = sqlContext.sessionState.catalog.getTableMetadata(table) + val catalogTable = sparkSession.sessionState.catalog.getTableMetadata(table) propertyKey match { case Some(p) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 07cc4a9482..f42b56fdc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command import scala.util.control.NonFatal -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.SQLBuilder import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} @@ -62,14 +62,14 @@ case class CreateViewCommand( "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") } - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. - val qe = sqlContext.executePlan(child) + val qe = sparkSession.executePlan(child) qe.assertAnalyzed() val analyzedPlan = qe.analyzed require(tableDesc.schema == Nil || tableDesc.schema.length == analyzedPlan.output.length) - val sessionState = sqlContext.sessionState + val sessionState = sparkSession.sessionState if (sessionState.catalog.tableExists(tableIdentifier)) { if (allowExisting) { @@ -77,7 +77,7 @@ case class CreateViewCommand( // already exists. } else if (replace) { // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - sessionState.catalog.alterTable(prepareTable(sqlContext, analyzedPlan)) + sessionState.catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) } else { // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already // exists. @@ -88,7 +88,7 @@ case class CreateViewCommand( } else { // Create the view if it doesn't exist. sessionState.catalog.createTable( - prepareTable(sqlContext, analyzedPlan), ignoreIfExists = false) + prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) } Seq.empty[Row] @@ -98,9 +98,9 @@ case class CreateViewCommand( * Returns a [[CatalogTable]] that can be used to save in the catalog. This comment canonicalize * SQL based on the analyzed plan, and also creates the proper schema for the view. 
*/ - private def prepareTable(sqlContext: SQLContext, analyzedPlan: LogicalPlan): CatalogTable = { + private def prepareTable(sparkSession: SparkSession, analyzedPlan: LogicalPlan): CatalogTable = { val viewSQL: String = - if (sqlContext.conf.canonicalView) { + if (sparkSession.sessionState.conf.canonicalView) { val logicalPlan = if (tableDesc.schema.isEmpty) { analyzedPlan @@ -108,7 +108,7 @@ case class CreateViewCommand( val projectList = analyzedPlan.output.zip(tableDesc.schema).map { case (attr, col) => Alias(attr, col.name)() } - sqlContext.executePlan(Project(projectList, analyzedPlan)).analyzed + sparkSession.executePlan(Project(projectList, analyzedPlan)).analyzed } new SQLBuilder(logicalPlan).toSQL } else { @@ -134,7 +134,7 @@ case class CreateViewCommand( // Validate the view SQL - make sure we can parse it and analyze it. // If we cannot analyze the generated query, there is probably a bug in SQL generation. try { - sqlContext.sql(viewSQL).queryExecution.assertAnalyzed() + sparkSession.sql(viewSQL).queryExecution.assertAnalyzed() } catch { case NonFatal(e) => throw new RuntimeException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 4e7214ce83..ef626ef5fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -59,7 +59,7 @@ import org.apache.spark.util.Utils * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. */ case class DataSource( - sqlContext: SQLContext, + sparkSession: SparkSession, className: String, paths: Seq[String] = Nil, userSpecifiedSchema: Option[StructType] = None, @@ -131,15 +131,15 @@ case class DataSource( val allPaths = caseInsensitiveOptions.get("path") val globbedPaths = allPaths.toSeq.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sqlContext.sessionState.hadoopConf) + val fs = hdfsPath.getFileSystem(sparkSession.sessionState.hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) + val fileCatalog: FileCatalog = new HDFSFileCatalog(sparkSession, options, globbedPaths, None) userSpecifiedSchema.orElse { format.inferSchema( - sqlContext, + sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) }.getOrElse { @@ -151,7 +151,8 @@ case class DataSource( private def sourceSchema(): SourceInfo = { providingClass.newInstance() match { case s: StreamSourceProvider => - val (name, schema) = s.sourceSchema(sqlContext, userSpecifiedSchema, className, options) + val (name, schema) = s.sourceSchema( + sparkSession.wrapped, userSpecifiedSchema, className, options) SourceInfo(name, schema) case format: FileFormat => @@ -171,7 +172,7 @@ case class DataSource( def createSource(metadataPath: String): Source = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource(sqlContext, metadataPath, userSpecifiedSchema, className, options) + s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options) case format: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) @@ -183,16 +184,16 @@ case class DataSource( val newOptions = options.filterKeys(_ != "path") + 
("basePath" -> path) val newDataSource = DataSource( - sqlContext, + sparkSession, paths = files, userSpecifiedSchema = Some(sourceInfo.schema), className = className, options = new CaseInsensitiveMap(newOptions)) - Dataset.ofRows(sqlContext, LogicalRelation(newDataSource.resolveRelation())) + Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation())) } new FileStreamSource( - sqlContext, metadataPath, path, sourceInfo.schema, dataFrameBuilder) + sparkSession, metadataPath, path, sourceInfo.schema, dataFrameBuilder) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed reading") @@ -202,14 +203,14 @@ case class DataSource( /** Returns a sink that can be used to continually write data. */ def createSink(): Sink = { providingClass.newInstance() match { - case s: StreamSinkProvider => s.createSink(sqlContext, options, partitionColumns) + case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns) case format: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) - new FileStreamSink(sqlContext, path, format) + new FileStreamSink(sparkSession, path, format) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed writing") @@ -225,7 +226,7 @@ case class DataSource( case Seq(singlePath) => try { val hdfsPath = new Path(singlePath) - val fs = hdfsPath.getFileSystem(sqlContext.sessionState.hadoopConf) + val fs = hdfsPath.getFileSystem(sparkSession.sessionState.hadoopConf) val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir) val res = fs.exists(metadataPath) res @@ -244,9 +245,9 @@ case class DataSource( val relation = (providingClass.newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. 
case (dataSource: SchemaRelationProvider, Some(schema)) => - dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema) + dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema) case (dataSource: RelationProvider, None) => - dataSource.createRelation(sqlContext, caseInsensitiveOptions) + dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions) case (_: SchemaRelationProvider, None) => throw new AnalysisException(s"A schema needs to be specified when using $className.") case (_: RelationProvider, Some(_)) => @@ -257,11 +258,10 @@ case class DataSource( case (format: FileFormat, _) if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val fileCatalog = - new StreamFileCatalog(sqlContext, basePath) + val fileCatalog = new StreamFileCatalog(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( - sqlContext, + sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) }.getOrElse { @@ -271,7 +271,7 @@ case class DataSource( } HadoopFsRelation( - sqlContext, + sparkSession, fileCatalog, partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema, @@ -284,7 +284,7 @@ case class DataSource( val allPaths = caseInsensitiveOptions.get("path") ++ paths val globbedPaths = allPaths.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sqlContext.sessionState.hadoopConf) + val fs = hdfsPath.getFileSystem(sparkSession.sessionState.hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) @@ -311,11 +311,11 @@ case class DataSource( } val fileCatalog: FileCatalog = - new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema) + new HDFSFileCatalog(sparkSession, options, globbedPaths, partitionSchema) val dataSchema = userSpecifiedSchema.map { schema => val equality = - if (sqlContext.conf.caseSensitiveAnalysis) { + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution } else { org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution @@ -324,7 +324,7 @@ case class DataSource( StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) }.orElse { format.inferSchema( - sqlContext, + sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) }.getOrElse { @@ -334,10 +334,10 @@ case class DataSource( } val enrichedOptions = - format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles()) + format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) HadoopFsRelation( - sqlContext, + sparkSession, fileCatalog, partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema.asNullable, @@ -363,7 +363,7 @@ case class DataSource( providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) + dataSource.createRelation(sparkSession.wrapped, mode, options, data) case format: FileFormat => // Don't glob path for the write path. The contracts here are: // 1. 
Only one output path can be specified on the write path; @@ -374,11 +374,11 @@ case class DataSource( val path = new Path(caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") })) - val fs = path.getFileSystem(sqlContext.sessionState.hadoopConf) + val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf) path.makeQualified(fs.getUri, fs.getWorkingDirectory) } - val caseSensitive = sqlContext.conf.caseSensitiveAnalysis + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumnDataTypes( data.schema, partitionColumns, caseSensitive) @@ -421,7 +421,7 @@ case class DataSource( options, data.logicalPlan, mode) - sqlContext.executePlan(plan).toRdd + sparkSession.executePlan(plan).toRdd case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 60238bd515..f7f68b1eb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.{Partition => RDDPartition, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileNameHolder, RDD} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.ColumnarBatch @@ -51,10 +51,10 @@ case class PartitionedFile( case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends RDDPartition class FileScanRDD( - @transient val sqlContext: SQLContext, + @transient private val sparkSession: SparkSession, readFunction: (PartitionedFile) => Iterator[InternalRow], @transient val filePartitions: Seq[FilePartition]) - extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + extends RDD[InternalRow](sparkSession.sparkContext, Nil) { override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 751daa0fe2..9e1308bed5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -74,14 +74,14 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { } val partitionColumns = - l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver) + l.resolve(files.partitionSchema, files.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = - l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver) + l.resolve(files.dataSchema, files.sparkSession.sessionState.analyzer.resolver) // Partition keys are not available in the statistics of the files. 
val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty) @@ -107,7 +107,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") val readFile = files.fileFormat.buildReader( - sqlContext = files.sqlContext, + sparkSession = files.sparkSession, dataSchema = files.dataSchema, partitionSchema = files.partitionSchema, requiredSchema = prunedDataSchema, @@ -115,7 +115,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { options = files.options) val plannedPartitions = files.bucketSpec match { - case Some(bucketing) if files.sqlContext.conf.bucketingEnabled => + case Some(bucketing) if files.sparkSession.sessionState.conf.bucketingEnabled => logInfo(s"Planning with ${bucketing.numBuckets} buckets") val bucketed = selectedPartitions.flatMap { p => @@ -134,9 +134,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { } case _ => - val defaultMaxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes - val openCostInBytes = files.sqlContext.conf.filesOpenCostInBytes - val defaultParallelism = files.sqlContext.sparkContext.defaultParallelism + val defaultMaxSplitBytes = files.sparkSession.sessionState.conf.filesMaxPartitionBytes + val openCostInBytes = files.sparkSession.sessionState.conf.filesOpenCostInBytes + val defaultParallelism = files.sparkSession.sparkContext.defaultParallelism val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum val bytesPerCore = totalBytes / defaultParallelism val maxSplitBytes = Math.min(defaultMaxSplitBytes, @@ -195,7 +195,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { DataSourceScanExec.create( readDataColumns ++ partitionColumns, new FileScanRDD( - files.sqlContext, + files.sparkSession, readFile, plannedPartitions), files, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 37c2c4517c..7b15e49641 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -32,15 +32,15 @@ private[sql] case class InsertIntoDataSource( overwrite: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = Dataset.ofRows(sqlContext, query) + val data = Dataset.ofRows(sparkSession, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) // Invalidate the cache. 
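// ---------------------------------------------------------------------------
// Editor's illustrative sketch, not part of the patch: the insert commands build
// their DataFrames directly from the session, as InsertIntoDataSource does above.
// `spark`, `plan` and `targetSchema` are hypothetical names.
val data = Dataset.ofRows(spark, plan)
val df   = spark.internalCreateDataFrame(data.queryExecution.toRdd, targetSchema)
// ---------------------------------------------------------------------------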
- sqlContext.cacheManager.invalidateCache(logicalRelation) + sparkSession.cacheManager.invalidateCache(logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index a636ca2f29..b2483e69a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -68,7 +68,7 @@ private[sql] case class InsertIntoHadoopFsRelation( override def children: Seq[LogicalPlan] = query :: Nil - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) { val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { @@ -78,7 +78,7 @@ private[sql] case class InsertIntoHadoopFsRelation( s"cannot save to file.") } - val hadoopConf = new Configuration(sqlContext.sessionState.hadoopConf) + val hadoopConf = new Configuration(sparkSession.sessionState.hadoopConf) val fs = outputPath.getFileSystem(hadoopConf) val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) @@ -111,14 +111,14 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionSet = AttributeSet(partitionColumns) val dataColumns = query.output.filterNot(partitionSet.contains) - val queryExecution = Dataset.ofRows(sqlContext, query).queryExecution - SQLExecution.withNewExecutionId(sqlContext, queryExecution) { + val queryExecution = Dataset.ofRows(sparkSession, query).queryExecution + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { val relation = WriteRelation( - sqlContext, + sparkSession, dataColumns.toStructType, qualifiedOutputPath.toString, - fileFormat.prepareWrite(sqlContext, _, options, dataColumns.toStructType), + fileFormat.prepareWrite(sparkSession, _, options, dataColumns.toStructType), bucketSpec) val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) { @@ -131,7 +131,7 @@ private[sql] case class InsertIntoHadoopFsRelation( dataColumns = dataColumns, inputSchema = query.output, PartitioningUtils.DEFAULT_PARTITION_NAME, - sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), + sparkSession.getConf(SQLConf.PARTITION_MAX_FILES), isAppend) } @@ -140,7 +140,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - sqlContext.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _) + sparkSession.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _) writerContainer.commitJob() refreshFunction() } catch { case cause: Throwable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index b9527db6d0..3b064a5bc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil 
-import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow @@ -36,9 +36,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} + /** A container for all the details required when writing to a table. */ case class WriteRelation( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, path: String, prepareJobForWrite: Job => OutputWriterFactory, @@ -66,7 +67,7 @@ private[sql] abstract class BaseWriterContainer( @transient private val jobContext: JobContext = job private val speculationEnabled: Boolean = - relation.sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + relation.sparkSession.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) // The following fields are initialized and used on both driver and executor side. @transient protected var outputCommitter: OutputCommitter = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 7d407a7747..fb047ff867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -49,14 +49,14 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] override def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { val csvOptions = new CSVOptions(options) // TODO: Move filtering. 
val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString) - val rdd = baseRdd(sqlContext, csvOptions, paths) + val rdd = baseRdd(sparkSession, csvOptions, paths) val firstLine = findFirstLine(csvOptions, rdd) val firstRow = new LineCsvReader(csvOptions).parseLine(firstLine) @@ -66,7 +66,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { firstRow.zipWithIndex.map { case (value, index) => s"C$index" } } - val parsedRdd = tokenRdd(sqlContext, csvOptions, header, paths) + val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths) val schema = if (csvOptions.inferSchemaFlag) { CSVInferSchema.infer(parsedRdd, header, csvOptions.nullValue) } else { @@ -80,7 +80,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { @@ -94,7 +94,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, @@ -103,8 +103,8 @@ class DefaultSource extends FileFormat with DataSourceRegister { val csvOptions = new CSVOptions(options) val headers = requiredSchema.fields.map(_.name) - val conf = new Configuration(sqlContext.sessionState.hadoopConf) - val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + val conf = new Configuration(sparkSession.sessionState.hadoopConf) + val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf)) (file: PartitionedFile) => { val lineIterator = { @@ -134,18 +134,18 @@ class DefaultSource extends FileFormat with DataSourceRegister { } private def baseRdd( - sqlContext: SQLContext, + sparkSession: SparkSession, options: CSVOptions, inputPaths: Seq[String]): RDD[String] = { - readText(sqlContext, options, inputPaths.mkString(",")) + readText(sparkSession, options, inputPaths.mkString(",")) } private def tokenRdd( - sqlContext: SQLContext, + sparkSession: SparkSession, options: CSVOptions, header: Array[String], inputPaths: Seq[String]): RDD[Array[String]] = { - val rdd = baseRdd(sqlContext, options, inputPaths) + val rdd = baseRdd(sparkSession, options, inputPaths) // Make sure firstLine is materialized before sending to executors val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null CSVRelation.univocityTokenizer(rdd, header, firstLine, options) @@ -168,14 +168,14 @@ class DefaultSource extends FileFormat with DataSourceRegister { } private def readText( - sqlContext: SQLContext, + sparkSession: SparkSession, options: CSVOptions, location: String): RDD[String] = { if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - sqlContext.sparkContext.textFile(location) + sparkSession.sparkContext.textFile(location) } else { val charset = options.charset - sqlContext.sparkContext + sparkSession.sparkContext .hadoopFile[LongWritable, Text, TextInputFormat](location) .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e7e94bbef8..7d0a3d9756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -74,15 +74,15 @@ case class CreateTempTableUsing( s"Temporary table '$tableIdent' should not have specified a database") } - def run(sqlContext: SQLContext): Seq[Row] = { + def run(sparkSession: SparkSession): Seq[Row] = { val dataSource = DataSource( - sqlContext, + sparkSession, userSpecifiedSchema = userSpecifiedSchema, className = provider, options = options) - sqlContext.sessionState.catalog.createTempTable( + sparkSession.sessionState.catalog.createTempTable( tableIdent.table, - Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan, + Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan, overrideIfExists = true) Seq.empty[Row] @@ -102,18 +102,18 @@ case class CreateTempTableUsingAsSelect( s"Temporary table '$tableIdent' should not have specified a database") } - override def run(sqlContext: SQLContext): Seq[Row] = { - val df = Dataset.ofRows(sqlContext, query) + override def run(sparkSession: SparkSession): Seq[Row] = { + val df = Dataset.ofRows(sparkSession, query) val dataSource = DataSource( - sqlContext, + sparkSession, className = provider, partitionColumns = partitionColumns, bucketSpec = None, options = options) val result = dataSource.write(mode, df) - sqlContext.sessionState.catalog.createTempTable( + sparkSession.sessionState.catalog.createTempTable( tableIdent.table, - Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan, + Dataset.ofRows(sparkSession, LogicalRelation(result)).logicalPlan, overrideIfExists = true) Seq.empty[Row] @@ -123,23 +123,23 @@ case class CreateTempTableUsingAsSelect( case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { // Refresh the given table's metadata first. - sqlContext.sessionState.catalog.refreshTable(tableIdent) + sparkSession.sessionState.catalog.refreshTable(tableIdent) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.sessionState.catalog.lookupRelation(tableIdent) + val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent) // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty + val isCached = sparkSession.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { // Create a data frame to represent the table. // TODO: Use uncacheTable once it supports database name. - val df = Dataset.ofRows(sqlContext, logicalPlan) + val df = Dataset.ofRows(sparkSession, logicalPlan) // Uncache the logicalPlan. - sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) + sparkSession.cacheManager.tryUncacheQuery(df, blocking = true) // Cache it again. 
- sqlContext.cacheManager.cacheQuery(df, Some(tableIdent.table)) + sparkSession.cacheManager.cacheQuery(df, Some(tableIdent.table)) } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala index 731b0047e5..2628788ad3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala @@ -119,7 +119,7 @@ abstract class OutputWriter { * @param options Configuration used when reading / writing data. */ case class HadoopFsRelation( - sqlContext: SQLContext, + sparkSession: SparkSession, location: FileCatalog, partitionSchema: StructType, dataSchema: StructType, @@ -127,6 +127,8 @@ case class HadoopFsRelation( fileFormat: FileFormat, options: Map[String, String]) extends BaseRelation with FileRelation { + override def sqlContext: SQLContext = sparkSession.wrapped + val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet StructType(dataSchema ++ partitionSchema.filterNot { column => @@ -160,7 +162,7 @@ trait FileFormat { * Spark will require that user specify the schema manually. */ def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] @@ -169,7 +171,7 @@ trait FileFormat { * can be useful for collecting necessary global information for scanning input data. */ def prepareRead( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Map[String, String] = options @@ -179,7 +181,7 @@ trait FileFormat { * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. */ def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory @@ -189,7 +191,7 @@ trait FileFormat { * * TODO: we should just have different traits for the different formats. 
*/ - def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = { + def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = { false } @@ -210,7 +212,7 @@ trait FileFormat { * @return */ def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, @@ -265,13 +267,13 @@ trait FileCatalog { * discovered partitions */ class HDFSFileCatalog( - val sqlContext: SQLContext, - val parameters: Map[String, String], - val paths: Seq[Path], - val partitionSchema: Option[StructType]) + sparkSession: SparkSession, + parameters: Map[String, String], + override val paths: Seq[Path], + partitionSchema: Option[StructType]) extends FileCatalog with Logging { - private val hadoopConf = new Configuration(sqlContext.sessionState.hadoopConf) + private val hadoopConf = new Configuration(sparkSession.sessionState.hadoopConf) var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] @@ -339,8 +341,8 @@ class HDFSFileCatalog( def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) private def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { - HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) + if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession.sparkContext) } else { val statuses: Seq[FileStatus] = paths.flatMap { path => val fs = path.getFileSystem(hadoopConf) @@ -412,7 +414,7 @@ class HDFSFileCatalog( PartitioningUtils.parsePartitions( leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(), + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled(), basePaths = basePaths) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index 4dcd261f5c..6ff50a3c61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -54,6 +54,6 @@ class DefaultSource extends RelationProvider with DataSourceRegister { val parts = JDBCRelation.columnPartition(partitionInfo) val properties = new Properties() // Additional properties that we will pass to getConnection parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) + JDBCRelation(url, table, parts, properties)(sqlContext.sparkSession) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 9e336422d1..bcf70fdc4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} +import 
org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -87,11 +87,13 @@ private[sql] case class JDBCRelation( url: String, table: String, parts: Array[Partition], - properties: Properties = new Properties())(@transient val sqlContext: SQLContext) + properties: Properties = new Properties())(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { + override def sqlContext: SQLContext = sparkSession.wrapped + override val needConversion: Boolean = false override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) @@ -104,7 +106,7 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( - sqlContext.sparkContext, + sparkSession.sparkContext, schema, url, properties, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 580a0e1de6..f9c34c6bb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -44,7 +44,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def shortName(): String = "json" override def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { if (files.isEmpty) { @@ -53,14 +53,14 @@ class DefaultSource extends FileFormat with DataSourceRegister { val parsedOptions: JSONOptions = new JSONOptions(options) val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord - .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) + .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) val jsonFiles = files.filterNot { status => val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.toArray val jsonSchema = InferSchema.infer( - createBaseRdd(sqlContext, jsonFiles), + createBaseRdd(sparkSession, jsonFiles), columnNameOfCorruptRecord, parsedOptions) checkConstraints(jsonSchema) @@ -70,7 +70,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { @@ -92,19 +92,19 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { - val conf = new 
Configuration(sqlContext.sessionState.hadoopConf) + val conf = new Configuration(sparkSession.sessionState.hadoopConf) val broadcastedConf = - sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf)) val parsedOptions: JSONOptions = new JSONOptions(options) val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord - .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) + .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() @@ -125,8 +125,10 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = { - val job = Job.getInstance(sqlContext.sessionState.hadoopConf) + private def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[String] = { + val job = Job.getInstance(sparkSession.sessionState.hadoopConf) val conf = job.getConfiguration val paths = inputPaths.map(_.getPath) @@ -135,7 +137,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { FileInputFormat.setInputPaths(job, paths: _*) } - sqlContext.sparkContext.hadoopRDD( + sparkSession.sparkContext.hadoopRDD( conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 28c6664085..b156581564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -65,12 +65,12 @@ private[sql] class DefaultSource override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val parquetOptions = new ParquetOptions(options, sqlContext.sessionState.conf) + val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -110,15 +110,15 @@ private[sql] class DefaultSource // and `CatalystWriteSupport` (writing actual rows to Parquet files). conf.set( SQLConf.PARQUET_BINARY_AS_STRING.key, - sqlContext.conf.isParquetBinaryAsString.toString) + sparkSession.sessionState.conf.isParquetBinaryAsString.toString) conf.set( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlContext.conf.isParquetINT96AsTimestamp.toString) + sparkSession.sessionState.conf.isParquetINT96AsTimestamp.toString) conf.set( SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - sqlContext.conf.writeLegacyParquetFormat.toString) + sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) @@ -135,7 +135,7 @@ private[sql] class DefaultSource } def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, parameters: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { // Should we merge schemas from all Parquet part-files? 
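[Editor's illustration, not part of the patch] The hunks around this point show the access pattern the change converges on inside the Parquet source: SQL settings are read through sparkSession.sessionState.conf, or through sparkSession.getConf for ConfigEntry keys, instead of through a SQLContext. A minimal sketch of that pattern, assuming it compiles in a package under org.apache.spark.sql (so the private[sql] sessionState accessor is visible) and using a hypothetical object name:

package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Hypothetical helper, for illustration only: reads the same Parquet-related settings
// that prepareWrite/inferSchema consult above, via the new SparkSession-based accessors.
object ParquetConfAccessExample {
  def summarize(sparkSession: SparkSession): String = {
    val conf = sparkSession.sessionState.conf            // session-scoped SQLConf
    val binaryAsString = conf.isParquetBinaryAsString    // replaces sqlContext.conf.isParquetBinaryAsString
    val int96AsTimestamp = conf.isParquetINT96AsTimestamp
    // ConfigEntry-style lookup, as used for PARQUET_SCHEMA_MERGING_ENABLED below
    val mergeSchema = sparkSession.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)
    s"binaryAsString=$binaryAsString, int96AsTimestamp=$int96AsTimestamp, mergeSchema=$mergeSchema"
  }
}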
@@ -143,10 +143,10 @@ private[sql] class DefaultSource parameters .get(ParquetRelation.MERGE_SCHEMA) .map(_.toBoolean) - .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + .getOrElse(sparkSession.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) val mergeRespectSummaries = - sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + sparkSession.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) val filesByType = splitFiles(files) @@ -218,7 +218,7 @@ private[sql] class DefaultSource .orElse(filesByType.data.headOption) .toSeq } - ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext) + ParquetRelation.mergeSchemasInParallel(filesToTouch, sparkSession) } case class FileTypes( @@ -249,21 +249,21 @@ private[sql] class DefaultSource /** * Returns whether the reader will return the rows as batch or not. */ - override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = { - val conf = SQLContext.getActive().get.conf + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && schema.forall(_.dataType.isInstanceOf[AtomicType]) } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { - val parquetConf = new Configuration(sqlContext.sessionState.hadoopConf) + val parquetConf = new Configuration(sparkSession.sessionState.hadoopConf) parquetConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) parquetConf.set( CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, @@ -281,13 +281,13 @@ private[sql] class DefaultSource // Sets flags for `CatalystSchemaConverter` parquetConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, - sqlContext.conf.getConf(SQLConf.PARQUET_BINARY_AS_STRING)) + sparkSession.getConf(SQLConf.PARQUET_BINARY_AS_STRING)) parquetConf.setBoolean( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP)) + sparkSession.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP)) // Try to push down filters when filter push-down is enabled. - val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) { + val pushed = if (sparkSession.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) { filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` @@ -299,16 +299,17 @@ private[sql] class DefaultSource } val broadcastedConf = - sqlContext.sparkContext.broadcast(new SerializableConfiguration(parquetConf)) + sparkSession.sparkContext.broadcast(new SerializableConfiguration(parquetConf)) // TODO: if you move this into the closure it reverts to the default values. // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). 
val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) - val enableVectorizedReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled && + val enableVectorizedReader: Boolean = + sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - val returningBatch = supportBatch(sqlContext, resultSchema) + val returningBatch = supportBatch(sparkSession, resultSchema) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -507,13 +508,13 @@ private[sql] object ParquetRelation extends Logging { } private[parquet] def readSchema( - footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { + footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { def parseParquetSchema(schema: MessageType): StructType = { val converter = new CatalystSchemaConverter( - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.writeLegacyParquetFormat) + sparkSession.sessionState.conf.isParquetBinaryAsString, + sparkSession.sessionState.conf.isParquetBinaryAsString, + sparkSession.sessionState.conf.writeLegacyParquetFormat) converter.convert(schema) } @@ -644,11 +645,13 @@ private[sql] object ParquetRelation extends Logging { * S3 nodes). */ def mergeSchemasInParallel( - filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val writeLegacyParquetFormat = sqlContext.conf.writeLegacyParquetFormat - val serializedConf = new SerializableConfiguration(sqlContext.sessionState.hadoopConf) + filesToTouch: Seq[FileStatus], + sparkSession: SparkSession): Option[StructType] = { + val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat + val serializedConf = + new SerializableConfiguration(sparkSession.sessionState.hadoopConf) // !! HACK ALERT !! // @@ -665,7 +668,7 @@ private[sql] object ParquetRelation extends Logging { // Issues a Spark job to read Parquet schema in parallel. val partiallyMergedSchemas = - sqlContext + sparkSession .sparkContext .parallelize(partialFileStatusInfo) .mapPartitions { iterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5b8dc4a3ee..b622f85941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} @@ -30,12 +30,12 @@ import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} /** * Try to replaces [[UnresolvedRelation]]s with [[ResolvedDataSource]]. 
*/ -private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[LogicalPlan] { +private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if u.tableIdentifier.database.isDefined => try { val dataSource = DataSource( - sqlContext, + sparkSession, paths = u.tableIdentifier.table :: Nil, className = u.tableIdentifier.database.get) val plan = LogicalRelation(dataSource.resolveRelation()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index f7ac1ac8e4..a0d680c708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} @@ -52,12 +52,12 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = Some(new StructType().add("value", StringType)) override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { @@ -84,15 +84,15 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { - val conf = new Configuration(sqlContext.sessionState.hadoopConf) + val conf = new Configuration(sparkSession.sessionState.hadoopConf) val broadcastedConf = - sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf)) file => { val unsafeRow = new UnsafeRow(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 8c2231335c..34bd243d58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging { StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - Dataset.ofRows(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) + Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d603f63a08..9c0406168e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging { } val schema = StructType(StructField(tableName, StringType) +: headerNames) - Dataset.ofRows(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index a86108862f..61c9c88cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -22,7 +22,7 @@ import java.util.UUID import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.execution.datasources.FileFormat object FileStreamSink { @@ -38,14 +38,14 @@ object FileStreamSink { * in the log. */ class FileStreamSink( - sqlContext: SQLContext, + sparkSession: SparkSession, path: String, fileFormat: FileFormat) extends Sink with Logging { private val basePath = new Path(path) private val logPath = new Path(basePath, FileStreamSink.metadataDir) - private val fileLog = new FileStreamSinkLog(sqlContext, logPath.toUri.toString) - private val fs = basePath.getFileSystem(sqlContext.sessionState.hadoopConf) + private val fileLog = new FileStreamSinkLog(sparkSession, logPath.toUri.toString) + private val fs = basePath.getFileSystem(sparkSession.sessionState.hadoopConf) override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { @@ -73,7 +73,7 @@ class FileStreamSink( private def writeFiles(data: DataFrame): Array[Path] = { val file = new Path(basePath, UUID.randomUUID().toString).toUri.toString data.write.parquet(file) - sqlContext.read + sparkSession.read .schema(data.schema) .parquet(file) .inputFiles diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index 6c5449a928..c548fbd369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -25,7 +25,7 @@ import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.json4s.jackson.Serialization.{read, write} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf /** @@ -66,8 +66,8 @@ case class SinkFileStatus( * When the reader uses `allFiles` to list all files, this method only returns the visible files * (drops the deleted files). 
*/ -class FileStreamSinkLog(sqlContext: SQLContext, path: String) - extends HDFSMetadataLog[Seq[SinkFileStatus]](sqlContext, path) { +class FileStreamSinkLog(sparkSession: SparkSession, path: String) + extends HDFSMetadataLog[Seq[SinkFileStatus]](sparkSession, path) { import FileStreamSinkLog._ @@ -80,11 +80,11 @@ class FileStreamSinkLog(sqlContext: SQLContext, path: String) * a live lock may happen if the compaction happens too frequently: one processing keeps deleting * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. */ - private val fileCleanupDelayMs = sqlContext.getConf(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + private val fileCleanupDelayMs = sparkSession.getConf(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) - private val isDeletingExpiredLog = sqlContext.getConf(SQLConf.FILE_SINK_LOG_DELETION) + private val isDeletingExpiredLog = sparkSession.getConf(SQLConf.FILE_SINK_LOG_DELETION) - private val compactInterval = sqlContext.getConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) + private val compactInterval = sparkSession.getConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) require(compactInterval > 0, s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + "to a positive value.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index aeb64c929c..e22a05bd3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -32,14 +32,14 @@ import org.apache.spark.util.collection.OpenHashSet * TODO Clean up the metadata files periodically */ class FileStreamSource( - sqlContext: SQLContext, + sparkSession: SparkSession, metadataPath: String, path: String, override val schema: StructType, dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging { - private val fs = new Path(path).getFileSystem(sqlContext.sessionState.hadoopConf) - private val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath) + private val fs = new Path(path).getFileSystem(sparkSession.sessionState.hadoopConf) + private val metadataLog = new HDFSMetadataLog[Seq[String]](sparkSession, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) private val seenFiles = new OpenHashSet[String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index dd6760d341..ddba3ccb1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.permission.FsPermission import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import 
org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** @@ -45,7 +45,7 @@ import org.apache.spark.sql.SQLContext * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing * files in a directory always shows the latest files. */ -class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) +class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) extends MetadataLog[T] with Logging { @@ -65,7 +65,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) override def accept(path: Path): Boolean = isBatchFile(path) } - private val serializer = new JavaSerializer(sqlContext.sparkContext.conf).newInstance() + private val serializer = new JavaSerializer(sparkSession.sparkContext.conf).newInstance() protected def batchIdToPath(batchId: Long): Path = { new Path(metadataPath, batchId.toString) @@ -212,7 +212,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) } private def createFileManager(): FileManager = { - val hadoopConf = new Configuration(sqlContext.sessionState.hadoopConf) + val hadoopConf = new Configuration(sparkSession.sessionState.hadoopConf) try { new FileContextManager(metadataPath, hadoopConf) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a1a1108447..b89144d727 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.OutputMode import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -28,20 +28,21 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, * plan incrementally. Possibly preserving state in between each execution. */ class IncrementalExecution( - ctx: SQLContext, + sparkSession: SparkSession, logicalPlan: LogicalPlan, outputMode: OutputMode, checkpointLocation: String, - currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { + currentBatchId: Long) + extends QueryExecution(sparkSession, logicalPlan) { // TODO: make this always part of planning. - val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil + val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy :: Nil // Modified planner with stateful operations. override def planner: SparkPlanner = new SparkPlanner( - sqlContext.sparkContext, - sqlContext.conf, + sparkSession.sparkContext, + sparkSession.sessionState.conf, stateStrategy) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 2a1fa1ba62..ea3c73d984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -44,13 +44,14 @@ import org.apache.spark.util.UninterruptibleThread * and the results are committed transactionally to the given [[Sink]]. 
*/ class StreamExecution( - override val sqlContext: SQLContext, + override val sparkSession: SparkSession, override val name: String, checkpointRoot: String, private[sql] val logicalPlan: LogicalPlan, val sink: Sink, val outputMode: OutputMode, - val trigger: Trigger) extends ContinuousQuery with Logging { + val trigger: Trigger) + extends ContinuousQuery with Logging { /** An monitor used to wait/notify when batches complete. */ private val awaitBatchLock = new Object @@ -108,7 +109,7 @@ class StreamExecution( * processed and the N-1th entry indicates which offsets have been durably committed to the sink. */ private val offsetLog = - new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets")) + new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets")) /** Whether the query is currently active or not */ override def isActive: Boolean = state == ACTIVE @@ -158,7 +159,7 @@ class StreamExecution( startLatch.countDown() // While active, repeatedly attempt to run batches. - SQLContext.setActive(sqlContext) + SQLContext.setActive(sparkSession.wrapped) populateStartOffsets() logDebug(s"Stream running from $committedOffsets to $availableOffsets") triggerExecutor.execute(() => { @@ -181,7 +182,7 @@ class StreamExecution( logError(s"Query $name terminated with error", e) } finally { state = TERMINATED - sqlContext.streams.notifyQueryTermination(StreamExecution.this) + sparkSession.streams.notifyQueryTermination(StreamExecution.this) postEvent(new QueryTerminated(this)) terminationLatch.countDown() } @@ -317,7 +318,7 @@ class StreamExecution( val optimizerStart = System.nanoTime() lastExecution = new IncrementalExecution( - sqlContext, + sparkSession, newPlan, outputMode, checkpointFile("state"), @@ -328,7 +329,7 @@ class StreamExecution( logDebug(s"Optimized batch in ${optimizerTime}ms") val nextBatch = - new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId - 1, nextBatch) awaitBatchLock.synchronized { @@ -344,7 +345,7 @@ class StreamExecution( } private def postEvent(event: ContinuousQueryListener.Event) { - sqlContext.streams.postListenerEvent(event) + sparkSession.streams.postListenerEvent(event) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala index a08a4bb4c3..b2bc31634c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql.execution.streaming import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.datasources.{FileCatalog, Partition, PartitionSpec} import org.apache.spark.sql.types.StructType -class StreamFileCatalog(sqlContext: SQLContext, path: Path) extends FileCatalog with Logging { +class StreamFileCatalog(sparkSession: SparkSession, path: Path) extends FileCatalog with Logging { val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") - val 
metadataLog = new FileStreamSinkLog(sqlContext, metadataDirectory.toUri.toString) - val fs = path.getFileSystem(sqlContext.sessionState.hadoopConf) + val metadataLog = new FileStreamSinkLog(sparkSession, metadataDirectory.toUri.toString) + val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf) override def paths: Seq[Path] = path :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3820968324..0d2a6dd929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -58,11 +58,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def schema: StructType = encoder.schema def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { - Dataset(sqlContext, logicalPlan) + Dataset(sqlContext.sparkSession, logicalPlan) } def toDF()(implicit sqlContext: SQLContext): DataFrame = { - Dataset.ofRows(sqlContext, logicalPlan) + Dataset.ofRows(sqlContext.sparkSession, logicalPlan) } def addData(data: A*): Offset = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 1341e45483..4a1f12d685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression} @@ -70,11 +70,11 @@ case class ScalarSubquery( /** * Plans scalar subqueries from that are present in the given [[SparkPlan]]. */ -case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { +case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan + val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f2448af991..fe63c80815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -932,7 +932,7 @@ object functions { * @since 1.5.0 */ def broadcast(df: DataFrame): DataFrame = { - Dataset.ofRows(df.sqlContext, BroadcastHint(df.logicalPlan)) + Dataset.ofRows(df.sparkSession, BroadcastHint(df.logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 04ad729659..1bda572e63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -37,9 +37,9 @@ import org.apache.spark.sql.util.ExecutionListenerManager /** - * A class that holds all session-specific state in a given [[SQLContext]]. 
+ * A class that holds all session-specific state in a given [[SparkSession]]. */ -private[sql] class SessionState(ctx: SQLContext) { +private[sql] class SessionState(sparkSession: SparkSession) { // Note: These are all lazy vals because they depend on each other (e.g. conf) and we // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. @@ -48,10 +48,12 @@ private[sql] class SessionState(ctx: SQLContext) { * SQL-specific key-value configurations. */ lazy val conf: SQLConf = new SQLConf - lazy val hadoopConf: Configuration = new Configuration(ctx.sparkContext.hadoopConfiguration) + lazy val hadoopConf: Configuration = { + new Configuration(sparkSession.sparkContext.hadoopConfiguration) + } // Automatically extract `spark.sql.*` entries and put it in our SQLConf - setConf(SQLContext.getSQLProperties(ctx.sparkContext.getConf)) + setConf(SQLContext.getSQLProperties(sparkSession.sparkContext.getConf)) lazy val experimentalMethods = new ExperimentalMethods @@ -68,7 +70,7 @@ private[sql] class SessionState(ctx: SQLContext) { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { case JarResource => addJar(resource.uri) - case FileResource => ctx.sparkContext.addFile(resource.uri) + case FileResource => sparkSession.sparkContext.addFile(resource.uri) case ArchiveResource => throw new AnalysisException( "Archive is not allowed to be loaded. If YARN mode is used, " + @@ -82,10 +84,10 @@ private[sql] class SessionState(ctx: SQLContext) { * Internal catalog for managing table and database states. */ lazy val catalog = new SessionCatalog( - ctx.externalCatalog, - functionResourceLoader, - functionRegistry, - conf) + sparkSession.externalCatalog, + functionResourceLoader, + functionRegistry, + conf) /** * Interface exposed to the user for registering user-defined functions. @@ -100,7 +102,7 @@ private[sql] class SessionState(ctx: SQLContext) { override val extendedResolutionRules = PreInsertCastAndRename :: DataSourceAnalysis :: - (if (conf.runSQLonFile) new ResolveDataSource(ctx) :: Nil else Nil) + (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog)) } @@ -120,7 +122,7 @@ private[sql] class SessionState(ctx: SQLContext) { * Planner that converts optimized logical plans to physical plans. */ def planner: SparkPlanner = - new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -131,14 +133,16 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. 
*/ - lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) + lazy val continuousQueryManager: ContinuousQueryManager = { + new ContinuousQueryManager(sparkSession) + } // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(ctx, plan) + def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) @@ -162,7 +166,7 @@ private[sql] class SessionState(ctx: SQLContext) { } def addJar(path: String): Unit = { - ctx.sparkContext.addJar(path) + sparkSession.sparkContext.addJar(path) } /** @@ -173,7 +177,7 @@ private[sql] class SessionState(ctx: SQLContext) { * in the external catalog. */ def analyze(tableName: String): Unit = { - AnalyzeTable(tableName).run(ctx) + AnalyzeTable(tableName).run(sparkSession) } def runNativeSql(sql: String): Seq[String] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 18e04c24a4..47b55e2547 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -57,7 +57,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { rows(0)) // dropna on an a dataframe with no column should return an empty data frame. - val empty = input.sqlContext.emptyDataFrame.select() + val empty = input.sparkSession.emptyDataFrame.select() assert(empty.na.drop().count() === 0L) // Make sure the columns are properly named. 
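[Editor's illustration, not part of the patch] The test suite above and the one below switch from df.sqlContext to df.sparkSession wherever they need the session that owns a DataFrame (emptyDataFrame, udf.register, sql). A minimal stand-alone sketch of that idiom, assuming the sparkSession accessor on Dataset is visible from the calling code and using a hypothetical helper name:

import org.apache.spark.sql.DataFrame

// Hypothetical helper, for illustration only: registers a UDF on the session that owns
// the DataFrame (mirroring the DataFrameSuite change below) and uses it right away.
// Assumes the input has an Int column "value" and a column "id", like the test fixture.
object OwningSessionExample {
  def withSquaredValue(df: DataFrame): DataFrame = {
    df.sparkSession.udf.register("squareUDF", (v: Int) => v * v)  // was: df.sqlContext.udf.register(...)
    df.selectExpr("id", "squareUDF(value) AS squared")
  }
}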
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4c18784126..681476b6e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -464,8 +464,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("callUDF in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - val sqlctx = df.sqlContext - sqlctx.udf.register("simpleUDF", (v: Int) => v * v) + df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v) checkAnswer( df.select($"id", callUDF("simpleUDF", $"value")), Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) @@ -618,7 +617,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("apply on query results (SPARK-5462)") { - val df = testData.sqlContext.sql("select key from testData") + val df = testData.sparkSession.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) } @@ -975,7 +974,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - Dataset.ofRows(sqlContext, OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(sqlContext.sparkSession, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d9b374b792..df8b3b7d87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest { def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(sqlContext) + LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession) case l: LocalRelation => val origin = localRelations.pop() l.copy(data = origin.data) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index c014f61679..dff6acc94b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 38318740a5..073e0b3f00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,7 +231,7 @@ object SparkPlanTest { } private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - val execution = new QueryExecution(sqlContext, null) { + val execution = 
new QueryExecution(sqlContext.sparkSession, null) { override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index fb70dbd961..9da0af3a76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -280,7 +280,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi )) val fakeRDD = new FileScanRDD( - sqlContext, + sqlContext.sparkSession, (file: PartitionedFile) => Iterator.empty, Seq(partition) ) @@ -414,7 +414,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi l.copy(relation = r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) } - Dataset.ofRows(sqlContext, bucketed) + Dataset.ofRows(sqlContext.sparkSession, bucketed) } else { df } @@ -449,7 +449,7 @@ class TestFileFormat extends FileFormat { * Spark will require that user specify the schema manually. */ override def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = Some( @@ -463,7 +463,7 @@ class TestFileFormat extends FileFormat { * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. */ override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { @@ -471,7 +471,7 @@ class TestFileFormat extends FileFormat { } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 1a7b62ca0a..e5588bec4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1316,7 +1316,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = DataSource( - sqlContext, + sqlContext.sparkSession, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, @@ -1324,7 +1324,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { options = Map("path" -> path)).resolveRelation() val d2 = DataSource( - sqlContext, + sqlContext.sparkSession, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index a164f4c733..df127d958e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -263,7 +263,7 @@ class 
FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { withTempDir { file => - val sinkLog = new FileStreamSinkLog(sqlContext, file.getCanonicalPath) + val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath) f(sinkLog) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 22e011cfb7..129b5a8c36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -59,7 +59,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir - val metadataLog = new HDFSMetadataLog[String](sqlContext, dir.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) @@ -86,14 +86,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { s"fs.$scheme.impl", classOf[FakeFileSystem].getName) withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext, s"$scheme://$temp") + val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext, s"$scheme://$temp") + val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.getLatest() === Some(0 -> "batch0")) assert(metadataLog2.get(None, 0) === Array(0 -> "batch0")) @@ -103,7 +103,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("HDFSMetadataLog: restart") { withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) @@ -111,7 +111,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { assert(metadataLog.getLatest() === Some(1 -> "batch1")) assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.get(1) === Some("batch1")) assert(metadataLog2.getLatest() === Some(1 -> "batch1")) @@ -126,7 +126,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { for (id <- 0 until 10) { new Thread() { override def run(): Unit = waiter { - val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + val metadataLog = + new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) try { 
var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) nextBatchId += 1 @@ -145,7 +146,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } waiter.await(timeout(10.seconds), dismissals(10)) - val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString)) assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 5f8514e1a2..612cfc7ec7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -28,13 +28,21 @@ class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) + SimpleDDLScan( + parameters("from").toInt, + parameters("TO").toInt, + parameters("Table"))(sqlContext.sparkSession) } } -case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) +case class SimpleDDLScan( + from: Int, + to: Int, + table: String)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = sparkSession.wrapped + override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, @@ -63,7 +71,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def buildScan(): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] - sqlContext.sparkContext.parallelize(from to to).map { e => + sparkSession.sparkContext.parallelize(from to to).map { e => InternalRow(UTF8String.fromString(s"people$e"), e * 2) }.asInstanceOf[RDD[Row]] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 14707774cf..51d04f2f4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -32,14 +32,16 @@ class FilteredScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext.sparkSession) } } -case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan { + override def sqlContext: SQLContext = sparkSession.wrapped + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: @@ -115,7 +117,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } - sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => + 
sparkSession.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 9bb901bfb3..cd0256db43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -29,14 +29,16 @@ class PrunedScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext.sparkSession) } } -case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedScan { + override def sqlContext: SQLContext = sparkSession.wrapped + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: @@ -48,7 +50,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo case "b" => (i: Int) => Seq(i * 2) } - sqlContext.sparkContext.parallelize(from to to).map(i => + sparkSession.sparkContext.parallelize(from to to).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 94d032f4ee..4f6df54417 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.execution.datasources.DataSource class ResolvedDataSourceSuite extends SparkFunSuite { private def getProvidingClass(name: String): Class[_] = - DataSource(sqlContext = null, className = name).providingClass + DataSource(sparkSession = null, className = name).providingClass test("jdbc") { assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 99f1661ad0..34b8726a92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -31,17 +31,21 @@ class SimpleScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + SimpleScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext.sparkSession) } } -case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = sparkSession.wrapped + override def schema: StructType = StructType(StructField("i", IntegerType, nullable = false) :: Nil) - override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) + override def buildScan(): RDD[Row] = { + 
sparkSession.sparkContext.parallelize(from to to).map(Row(_)) + } } class AllDataTypesScanSource extends SchemaRelationProvider { @@ -53,23 +57,27 @@ class AllDataTypesScanSource extends SchemaRelationProvider { parameters("option_with_underscores") parameters("option.with.dots") - AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext) + AllDataTypesScan( + parameters("from").toInt, + parameters("TO").toInt, schema)(sqlContext.sparkSession) } } case class AllDataTypesScan( from: Int, to: Int, - userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) + userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = sparkSession.wrapped + override def schema: StructType = userSpecifiedSchema override def needConversion: Boolean = true override def buildScan(): RDD[Row] = { - sqlContext.sparkContext.parallelize(from to to).map { i => + sparkSession.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", s"str_$i".getBytes(StandardCharsets.UTF_8), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fcfac359f3..5577c9f3ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -257,7 +257,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(sqlContext, plan) + Dataset.ofRows(sqlContext.sparkSession, plan) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index d270775af6..9799c6d42b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.internal.{RuntimeConfigImpl, SessionState, SQLConf} * A special [[SQLContext]] prepared for testing. 
*/ private[sql] class TestSQLContext( - @transient private val sparkSession: SparkSession, + @transient override val sparkSession: SparkSession, isRootContext: Boolean) extends SQLContext(sparkSession, isRootContext) { self => @@ -57,7 +57,7 @@ private[sql] class TestSQLContext( private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => @transient - protected[sql] override lazy val sessionState: SessionState = new SessionState(wrapped) { + protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { new SQLConf { clear() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index edb87b94ea..01b7cfbd2e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -24,7 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ @@ -45,16 +45,15 @@ import org.apache.spark.sql.types._ * This is still used for things like creating data source tables, but in the future will be * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. */ -private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { - private val conf = hive.conf - private val sessionState = hive.sessionState.asInstanceOf[HiveSessionState] - private val client = hive.sharedState.asInstanceOf[HiveSharedState].metadataHive - private val hiveconf = sessionState.hiveconf +private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { + private val conf = sparkSession.conf + private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] + private val client = sparkSession.sharedState.asInstanceOf[HiveSharedState].metadataHive /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) - private def getCurrentDatabase: String = hive.sessionState.catalog.getCurrentDatabase + private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { QualifiedTableName( @@ -124,7 +123,7 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { val options = table.storage.serdeProperties val dataSource = DataSource( - hive, + sparkSession, userSpecifiedSchema = userSpecifiedSchema, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -179,12 +178,12 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." 
- case None => SubqueryAlias(table.identifier.table, hive.parseSql(viewText)) - case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText)) + case None => SubqueryAlias(table.identifier.table, sparkSession.parseSql(viewText)) + case Some(aliasText) => SubqueryAlias(aliasText, sparkSession.parseSql(viewText)) } } else { MetastoreRelation( - qualifiedTableName.database, qualifiedTableName.name, alias)(table, client, hive) + qualifiedTableName.database, qualifiedTableName.name, alias)(table, client, sparkSession) } } @@ -275,19 +274,20 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { val hadoopFsRelation = cached.getOrElse { val paths = new Path(metastoreRelation.catalogTable.storage.locationUri.get) :: Nil - val fileCatalog = new MetaStoreFileCatalog(hive, paths, partitionSpec) + val fileCatalog = new MetaStoreFileCatalog(sparkSession, paths, partitionSpec) val inferredSchema = if (fileType.equals("parquet")) { - val inferredSchema = defaultSource.inferSchema(hive, options, fileCatalog.allFiles()) + val inferredSchema = + defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()) inferredSchema.map { inferred => ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred) }.getOrElse(metastoreSchema) } else { - defaultSource.inferSchema(hive, options, fileCatalog.allFiles()).get + defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get } val relation = HadoopFsRelation( - sqlContext = hive, + sparkSession = sparkSession, location = fileCatalog, partitionSchema = partitionSchema, dataSchema = inferredSchema, @@ -314,7 +314,7 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { val created = LogicalRelation( DataSource( - sqlContext = hive, + sparkSession = sparkSession, paths = paths, userSpecifiedSchema = Some(metastoreRelation.schema), bucketSpec = bucketSpec, @@ -436,7 +436,8 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p - case CreateViewCommand(table, child, allowExisting, replace, sql) if !conf.nativeView => + case CreateViewCommand(table, child, allowExisting, replace, sql) + if !sessionState.conf.nativeView => HiveNativeCommand(sql) case p @ CreateTableAsSelectLogicalPlan(table, child, allowExisting) => @@ -462,7 +463,7 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( TableIdentifier(desc.identifier.table), - conf.defaultDataSourceName, + sessionState.conf.defaultDataSourceName, temporary = false, Array.empty[String], bucketSpec = None, @@ -538,13 +539,17 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { * the information from the metastore. 
*/ private[hive] class MetaStoreFileCatalog( - ctx: SQLContext, + sparkSession: SparkSession, paths: Seq[Path], partitionSpecFromHive: PartitionSpec) - extends HDFSFileCatalog(ctx, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { + extends HDFSFileCatalog( + sparkSession, + Map.empty, + paths, + Some(partitionSpecFromHive.partitionColumns)) { override def getStatus(path: Path): Array[FileStatus] = { - val fs = path.getFileSystem(ctx.sessionState.hadoopConf) + val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf) fs.listStatus(path) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 9e527073d4..f70131ec86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.{AnalysisException, SQLContext} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, - context: SQLContext, + sparkSession: SparkSession, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: SQLConf, @@ -82,7 +82,7 @@ private[sql] class HiveSessionCatalog( // essentially a cache for metastore tables. However, it relies on a lot of session-specific // things so it would be a lot of work to split its functionality between HiveSessionCatalog // and HiveCatalog. We should still do it at some point... - private val metastoreCatalog = new HiveMetastoreCatalog(context) + private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession) val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index bf0288c9f7..4a8978e553 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -33,11 +33,14 @@ import org.apache.spark.sql.internal.SessionState /** * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. */ -private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) { +private[hive] class HiveSessionState(sparkSession: SparkSession) + extends SessionState(sparkSession) { self => - private lazy val sharedState: HiveSharedState = ctx.sharedState.asInstanceOf[HiveSharedState] + private lazy val sharedState: HiveSharedState = { + sparkSession.sharedState.asInstanceOf[HiveSharedState] + } /** * A Hive client used for execution. 
@@ -72,8 +75,8 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) new HiveSessionCatalog( sharedState.externalCatalog, metadataHive, - ctx, - ctx.sessionState.functionResourceLoader, + sparkSession, + functionResourceLoader, functionRegistry, conf, hiveconf) @@ -91,7 +94,7 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) catalog.PreInsertionCasts :: PreInsertCastAndRename :: DataSourceAnalysis :: - (if (conf.runSQLonFile) new ResolveDataSource(ctx) :: Nil else Nil) + (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog)) } @@ -101,9 +104,9 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) * Planner that takes into account Hive-specific strategies. */ override def planner: SparkPlanner = { - new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) with HiveStrategies { - override val context: SQLContext = ctx + override val sparkSession: SparkSession = self.sparkSession override val hiveconf: HiveConf = self.hiveconf override def strategies: Seq[Strategy] = { @@ -225,7 +228,7 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) // TODO: why do we get this from SparkConf but not SQLConf? def hiveThriftServerSingleSession: Boolean = { - ctx.sparkContext.conf.getBoolean( + sparkSession.sparkContext.conf.getBoolean( "spark.sql.hive.thriftServer.singleSession", defaultValue = false) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 2bea32b144..7d1f87f390 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -33,7 +33,7 @@ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. 
self: SparkPlanner => - val context: SQLContext + val sparkSession: SparkSession val hiveconf: HiveConf object Scripts extends Strategy { @@ -78,7 +78,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScanExec(_, relation, pruningPredicates)(context, hiveconf)) :: Nil + HiveTableScanExec(_, relation, pruningPredicates)(sparkSession, hiveconf)) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala index cd45706841..0520e75306 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Expression} @@ -42,7 +42,7 @@ private[hive] case class MetastoreRelation( alias: Option[String]) (val catalogTable: CatalogTable, @transient private val client: HiveClient, - @transient private val sqlContext: SQLContext) + @transient private val sparkSession: SparkSession) extends LeafNode with MultiInstanceRelation with FileRelation with CatalogRelation { override def equals(other: Any): Boolean = other match { @@ -58,7 +58,7 @@ private[hive] case class MetastoreRelation( Objects.hashCode(databaseName, tableName, alias, output) } - override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sqlContext :: Nil + override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil private def toHiveColumn(c: CatalogColumn): FieldSchema = { new FieldSchema(c.name, c.dataType, c.comment.orNull) @@ -124,7 +124,7 @@ private[hive] case class MetastoreRelation( // if the size is still less than zero, we use default size Option(totalSize).map(_.toLong).filter(_ > 0) .getOrElse(Option(rawDataSize).map(_.toLong).filter(_ > 0) - .getOrElse(sqlContext.conf.defaultSizeInBytes))) + .getOrElse(sparkSession.sessionState.conf.defaultSizeInBytes))) } ) @@ -133,7 +133,7 @@ private[hive] case class MetastoreRelation( private lazy val allPartitions: Seq[CatalogTablePartition] = client.getAllPartitions(catalogTable) def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { + val rawPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { client.getPartitionsByFilter(catalogTable, predicates) } else { allPartitions @@ -226,6 +226,6 @@ private[hive] case class MetastoreRelation( } override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName, alias)(catalogTable, client, sqlContext) + MetastoreRelation(databaseName, tableName, alias)(catalogTable, client, sparkSession) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index e95069e830..af0317f7a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ 
-36,7 +36,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -61,7 +61,7 @@ private[hive] class HadoopTableReader( @transient private val attributes: Seq[Attribute], @transient private val relation: MetastoreRelation, - @transient private val sc: SQLContext, + @transient private val sparkSession: SparkSession, hiveconf: HiveConf) extends TableReader with Logging { @@ -69,15 +69,15 @@ class HadoopTableReader( // https://hadoop.apache.org/docs/r1.0.4/mapred-default.html // // In order keep consistency with Hive, we will let it be 0 in local mode also. - private val _minSplitsPerRDD = if (sc.sparkContext.isLocal) { + private val _minSplitsPerRDD = if (sparkSession.sparkContext.isLocal) { 0 // will splitted based on block by default. } else { - math.max(hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions) + math.max(hiveconf.getInt("mapred.map.tasks", 1), sparkSession.sparkContext.defaultMinPartitions) } - SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveconf) + SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sparkSession.sparkContext.conf, hiveconf) private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableConfiguration(hiveconf)) + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hiveconf)) override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( @@ -153,7 +153,7 @@ class HadoopTableReader( def verifyPartitionPath( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]): Map[HivePartition, Class[_ <: Deserializer]] = { - if (!sc.conf.verifyPartitionPath) { + if (!sparkSession.sessionState.conf.verifyPartitionPath) { partitionToDeserializer } else { var existPathSet = collection.mutable.Set[String]() @@ -246,7 +246,7 @@ class HadoopTableReader( // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[InternalRow](sc.sparkContext) + new EmptyRDD[InternalRow](sparkSession.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -278,7 +278,7 @@ class HadoopTableReader( val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ val rdd = new HadoopRDD( - sc.sparkContext, + sparkSession.sparkContext, _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 9240f9c7d2..08d4b99d30 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} import 
org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.command.RunnableCommand @@ -42,7 +42,7 @@ case class CreateTableAsSelect( override def children: Seq[LogicalPlan] = Seq(query) - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { lazy val metastoreRelation: MetastoreRelation = { import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -68,24 +68,24 @@ case class CreateTableAsSelect( withFormat } - sqlContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) + sparkSession.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) // Get the Metastore Relation - sqlContext.sessionState.catalog.lookupRelation(tableIdentifier) match { + sparkSession.sessionState.catalog.lookupRelation(tableIdentifier) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - if (sqlContext.sessionState.catalog.tableExists(tableIdentifier)) { + if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { throw new AnalysisException(s"$tableIdentifier already exists.") } } else { - sqlContext.executePlan(InsertIntoTable( + sparkSession.executePlan(InsertIntoTable( metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 0f72091096..cc5bbf59db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ @@ -48,8 +48,8 @@ case class HiveTableScanExec( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Seq[Expression])( - @transient val context: SQLContext, - @transient val hiveconf: HiveConf) + @transient private val sparkSession: SparkSession, + @transient private val hiveconf: HiveConf) extends LeafExecNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, @@ -84,7 +84,7 @@ case class HiveTableScanExec( @transient private[this] val hadoopReader = - new HadoopTableReader(attributes, relation, context, hiveExtraConf) + new HadoopTableReader(attributes, relation, sparkSession, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1095f5fd95..cb49fc910b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{HadoopRDD, RDD} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -52,17 +52,17 @@ private[sql] class DefaultSource override def toString: String = "ORC" override def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { OrcFileOperator.readSchema( files.map(_.getPath.toUri.toString), - Some(new Configuration(sqlContext.sessionState.hadoopConf)) + Some(new Configuration(sparkSession.sessionState.hadoopConf)) ) } override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { @@ -109,15 +109,15 @@ private[sql] class DefaultSource } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { - val orcConf = new Configuration(sqlContext.sessionState.hadoopConf) + val orcConf = new Configuration(sparkSession.sessionState.hadoopConf) - if (sqlContext.conf.orcFilterPushDown) { + if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(filters.toArray).foreach { f => orcConf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) @@ -125,7 +125,8 @@ private[sql] class DefaultSource } } - val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(orcConf)) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(orcConf)) (file: PartitionedFile) => { val conf = broadcastedConf.value.value @@ -270,7 +271,7 @@ private[orc] class OrcOutputWriter( } private[orc] case class OrcTableScan( - @transient sqlContext: SQLContext, + @transient sparkSession: SparkSession, attributes: Seq[Attribute], filters: Array[Filter], @transient inputPaths: Seq[FileStatus]) @@ -278,11 +279,11 @@ private[orc] case class OrcTableScan( with HiveInspectors { def execute(): RDD[InternalRow] = { - val job = Job.getInstance(new Configuration(sqlContext.sessionState.hadoopConf)) + val job = Job.getInstance(new Configuration(sparkSession.sessionState.hadoopConf)) val conf = job.getConfiguration // Tries to push down filters if ORC filter push-down is enabled - if (sqlContext.conf.orcFilterPushDown) { + if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(filters).foreach { f => conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) @@ -294,14 +295,14 @@ private[orc] case class OrcTableScan( val orcFormat = new DefaultSource val dataSchema = orcFormat - .inferSchema(sqlContext, Map.empty, inputPaths) + .inferSchema(sparkSession, Map.empty, inputPaths) .getOrElse(sys.error("Failed to read schema from target ORC files.")) // Sets requested columns OrcRelation.setRequiredColumns(conf, dataSchema, StructType.fromAttributes(attributes)) if (inputPaths.isEmpty) { // the input path probably be pruned, return an 
empty RDD. - return sqlContext.sparkContext.emptyRDD[InternalRow] + return sparkSession.sparkContext.emptyRDD[InternalRow] } FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) @@ -309,7 +310,7 @@ private[orc] case class OrcTableScan( classOf[OrcInputFormat] .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]] - val rdd = sqlContext.sparkContext.hadoopRDD( + val rdd = sparkSession.sparkContext.hadoopRDD( conf.asInstanceOf[JobConf], inputFormatClass, classOf[NullWritable], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 04b2494043..f74e5cd6f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -71,7 +71,9 @@ object TestHive * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(@transient val sparkSession: TestHiveSparkSession, isRootContext: Boolean) +class TestHiveContext( + @transient override val sparkSession: TestHiveSparkSession, + isRootContext: Boolean) extends SQLContext(sparkSession, isRootContext) { def this(sc: SparkContext) { @@ -106,11 +108,11 @@ class TestHiveContext(@transient val sparkSession: TestHiveSparkSession, isRootC private[hive] class TestHiveSparkSession( - sc: SparkContext, + @transient private val sc: SparkContext, val warehousePath: File, scratchDirPath: File, metastoreTemporaryConf: Map[String, String], - existingSharedState: Option[TestHiveSharedState]) + @transient private val existingSharedState: Option[TestHiveSharedState]) extends SparkSession(sc) with Logging { self => def this(sc: SparkContext) { @@ -463,7 +465,7 @@ private[hive] class TestHiveSparkSession( private[hive] class TestHiveQueryExecution( sparkSession: TestHiveSparkSession, logicalPlan: LogicalPlan) - extends QueryExecution(new SQLContext(sparkSession), logicalPlan) with Logging { + extends QueryExecution(sparkSession, logicalPlan) with Logging { def this(sparkSession: TestHiveSparkSession, sql: String) { this(sparkSession, sparkSession.sessionState.sqlParser.parsePlan(sql)) @@ -525,7 +527,7 @@ private[hive] class TestHiveSharedState( private[hive] class TestHiveSessionState(sparkSession: TestHiveSparkSession) - extends HiveSessionState(new SQLContext(sparkSession)) { + extends HiveSessionState(sparkSession) { override lazy val conf: SQLConf = { new SQLConf { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala index b121600dae..27c9e992de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala @@ -64,7 +64,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext, plan)) + checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext.sparkSession, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 5965cdc81c..7cd01c9104 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -701,7 +701,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // Manually create a metastore data source table. CreateDataSourceTableUtils.createDataSourceTable( - sqlContext = sqlContext, + sparkSession = sqlContext.sparkSession, tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -910,7 +910,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) CreateDataSourceTableUtils.createDataSourceTable( - sqlContext = sqlContext, + sparkSession = sqlContext.sparkSession, tableIdent = TableIdentifier("not_skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -925,7 +925,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .forall(column => DataTypeParser.parse(column.dataType) == StringType)) CreateDataSourceTableUtils.createDataSourceTable( - sqlContext = sqlContext, + sparkSession = sqlContext.sparkSession, tableIdent = TableIdentifier("skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index bc87d3ef38..b16c9c133b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -975,7 +975,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. // todo: remove it? - val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan) + val newActual = Dataset.ofRows(sqlContext.sparkSession, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index 4a2d190353..5a8a7f0ab5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType @@ -33,7 +33,7 @@ class CommitFailureTestSource extends SimpleTextSource { * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. 
*/ override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index eced8ed57f..e4bd1f93c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} -import org.apache.spark.sql.{sources, Row, SQLContext} +import org.apache.spark.sql.{sources, Row, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection @@ -37,14 +37,14 @@ class SimpleTextSource extends FileFormat with DataSourceRegister { override def shortName(): String = "test" override def inferSchema( - sqlContext: SQLContext, + sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { Some(DataType.fromJson(options("dataSchema")).asInstanceOf[StructType]) } override def prepareWrite( - sqlContext: SQLContext, + sparkSession: SparkSession, job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = new OutputWriterFactory { @@ -58,7 +58,7 @@ class SimpleTextSource extends FileFormat with DataSourceRegister { } override def buildReader( - sqlContext: SQLContext, + sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, requiredSchema: StructType, @@ -74,9 +74,9 @@ class SimpleTextSource extends FileFormat with DataSourceRegister { inputAttributes.find(_.name == field.name) } - val conf = new Configuration(sqlContext.sessionState.hadoopConf) + val conf = new Configuration(sparkSession.sessionState.hadoopConf) val broadcastedConf = - sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf)) (file: PartitionedFile) => { val predicate = {
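
For reviewers following the external data source changes: the TableScanSuite, PrunedScanSuite and FilteredScanSuite hunks above all converge on the same shape, reproduced below as a standalone sketch (the class names `RangeScanSource`/`RangeScan` are made up for illustration). The relation now carries a `SparkSession` in its second parameter list and satisfies the unchanged public `BaseRelation.sqlContext` contract via `wrapped`. Note one assumption: the suites call `sqlContext.sparkSession` from inside `org.apache.spark.sql` packages, and this patch does not show whether that accessor is visible to third-party code at this commit.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Provider in the style of the updated TableScanSuite: the public entry point still
// receives a SQLContext, but the relation itself is constructed with a SparkSession.
class RangeScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Unwrapping the session this way mirrors the test suites; its visibility outside
    // org.apache.spark.sql at this commit is an assumption (see note above).
    RangeScan(parameters("from").toInt, parameters("to").toInt)(sqlContext.sparkSession)
  }
}

case class RangeScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  // BaseRelation still requires a SQLContext; `wrapped` is the compatibility view this
  // patch uses wherever a SQLContext-typed value is still needed.
  override def sqlContext: SQLContext = sparkSession.wrapped

  override def schema: StructType =
    StructType(StructField("i", IntegerType, nullable = false) :: Nil)

  override def buildScan(): RDD[Row] =
    sparkSession.sparkContext.parallelize(from to to).map(Row(_))
}
```

Loading such a source is unchanged, since `RelationProvider.createRelation` keeps its SQLContext-typed signature: `sqlContext.read.format(classOf[RangeScanSource].getName).option("from", "1").option("to", "10").load()`.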
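
On the consumer side, the test diffs (DataFrameSuite, DataFrameNaFunctionsSuite) reach the session through the data itself rather than through `df.sqlContext`. A minimal sketch of that pattern follows, with a hypothetical helper name and the assumption that `Dataset.sparkSession` is publicly reachable at this commit; the input is assumed to have `id` and `value` columns, like the test fixture.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.callUDF

object SessionFirstUdf {
  // Register a UDF through the DataFrame's owning SparkSession (df.sparkSession),
  // mirroring the updated "callUDF in SQLContext" test, then use it in a projection.
  def squared(df: DataFrame): DataFrame = {
    df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
    df.select(df("id"), callUDF("simpleUDF", df("value")))
  }
}
```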