author     Andrew Or <andrew@databricks.com>        2016-04-25 20:54:31 -0700
committer  Reynold Xin <rxin@databricks.com>        2016-04-25 20:54:31 -0700
commit     18c2c92580bdc27aa5129d9e7abda418a3633ea6 (patch)
tree       5d4222dd42ea584ad8259fbe7e22f1eb9b5bdb3c
parent     fa3c06987e6148975dd54b629bd9094224358175 (diff)
[SPARK-14861][SQL] Replace internal usages of SQLContext with SparkSession
## What changes were proposed in this pull request?

In Spark 2.0, `SparkSession` is the new thing. Internally we should stop using `SQLContext` everywhere, since it's no longer supposed to be the main user-facing API. In this patch I took care not to break any public APIs. The one place that's suspect is `o.a.s.ml.source.libsvm.DefaultSource`, but according to mengxr it's not supposed to be public, so it's OK to change the underlying `FileFormat` trait.

**Reviewers**: This is a big patch that may be difficult to review, but the changes are actually really straightforward. If you prefer, I can break it up into a few smaller patches, but that will delay the progress of this issue a little.

## How was this patch tested?

No change in functionality intended.

Author: Andrew Or <andrew@databricks.com>

Closes #12625 from andrewor14/spark-session-refactor.
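To illustrate the shape of the change, here is a minimal sketch (not taken from the patch; `broadcastModel` is a hypothetical helper) of the mechanical rewrite applied throughout: internal code that previously reached the `SparkContext` through a `Dataset`'s legacy `SQLContext` now goes through its `SparkSession`.

```scala
// Illustrative sketch only -- `broadcastModel` is a made-up helper, not part of this patch.
import scala.reflect.ClassTag

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.Dataset

def broadcastModel[M: ClassTag](dataset: Dataset[_], model: M): Broadcast[M] = {
  // Before this patch: dataset.sqlContext.sparkContext.broadcast(model)
  // After this patch: route through the SparkSession attached to the Dataset.
  dataset.sparkSession.sparkContext.broadcast(model)
}
```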
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 46
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 8
-rw-r--r--  project/MimaExcludes.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 38
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 112
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala | 47
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala | 42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala | 50
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala | 38
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala | 56
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala | 56
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala | 61
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala | 15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 47
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala | 6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala | 19
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala | 12
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 18
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala | 12
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala | 8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 29
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala | 12
101 files changed, 785 insertions(+), 715 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 39a698af15..1bd6dae7dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -199,7 +199,7 @@ final class GBTClassificationModel private[ml](
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c2b440059b..ec7fc977b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -840,8 +840,8 @@ class BinaryLogisticRegressionSummary private[classification] (
@Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
- private val sqlContext = predictions.sqlContext
- import sqlContext.implicits._
+ private val sparkSession = predictions.sparkSession
+ import sparkSession.implicits._
/**
* Returns a BinaryClassificationMetrics object.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index dfa711b243..c04ecc88ae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -181,7 +181,7 @@ final class RandomForestClassificationModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index c57ceba4a9..27c813dd61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
import org.apache.spark.sql.types.StructType
@@ -356,14 +356,14 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* Model fitted by [[LDA]].
*
* @param vocabSize Vocabulary size (number of terms or terms in the vocabulary)
- * @param sqlContext Used to construct local DataFrames for returning query results
+ * @param sparkSession Used to construct local DataFrames for returning query results
*/
@Since("1.6.0")
@Experimental
-sealed abstract class LDAModel private[ml] (
+sealed abstract class LDAModel protected[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
- @Since("1.6.0") @transient protected val sqlContext: SQLContext)
+ @Since("1.6.0") @transient protected[ml] val sparkSession: SparkSession)
extends Model[LDAModel] with LDAParams with Logging with MLWritable {
// NOTE to developers:
@@ -405,7 +405,7 @@ sealed abstract class LDAModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
- val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
+ val t = udf(oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext))
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
@@ -495,7 +495,7 @@ sealed abstract class LDAModel private[ml] (
case ((termIndices, termWeights), topic) =>
(topic, termIndices.toSeq, termWeights.toSeq)
}
- sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
+ sparkSession.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
}
@Since("1.6.0")
@@ -512,16 +512,16 @@ sealed abstract class LDAModel private[ml] (
*/
@Since("1.6.0")
@Experimental
-class LocalLDAModel private[ml] (
+class LocalLDAModel protected[ml] (
uid: String,
vocabSize: Int,
@Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
- sqlContext: SQLContext)
- extends LDAModel(uid, vocabSize, sqlContext) {
+ sparkSession: SparkSession)
+ extends LDAModel(uid, vocabSize, sparkSession) {
@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
- val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
}
@@ -554,7 +554,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
oldModel.topicConcentration, oldModel.gammaShape)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -565,7 +565,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
override def load(path: String): LocalLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
"gammaShape")
.head()
@@ -576,7 +576,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val gammaShape = data.getAs[Double](4)
val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
gammaShape)
- val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+ val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
@@ -604,13 +604,13 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
*/
@Since("1.6.0")
@Experimental
-class DistributedLDAModel private[ml] (
+class DistributedLDAModel protected[ml] (
uid: String,
vocabSize: Int,
private val oldDistributedModel: OldDistributedLDAModel,
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
private var oldLocalModelOption: Option[OldLocalLDAModel])
- extends LDAModel(uid, vocabSize, sqlContext) {
+ extends LDAModel(uid, vocabSize, sparkSession) {
override protected def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
@@ -628,12 +628,12 @@ class DistributedLDAModel private[ml] (
* WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
*/
@Since("1.6.0")
- def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
@Since("1.6.0")
override def copy(extra: ParamMap): DistributedLDAModel = {
- val copied =
- new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
+ val copied = new DistributedLDAModel(
+ uid, vocabSize, oldDistributedModel, sparkSession, oldLocalModelOption)
copyValues(copied, extra).setParent(parent)
copied
}
@@ -688,7 +688,7 @@ class DistributedLDAModel private[ml] (
@DeveloperApi
@Since("2.0.0")
def deleteCheckpointFiles(): Unit = {
- val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ val fs = FileSystem.get(sparkSession.sparkContext.hadoopConfiguration)
_checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
_checkpointFiles = Array.empty[String]
}
@@ -720,7 +720,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(
- metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+ metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
@@ -856,9 +856,9 @@ class LDA @Since("1.6.0") (
val oldModel = oldLDA.run(oldData)
val newModel = oldModel match {
case m: OldLocalLDAModel =>
- new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+ new LocalLDAModel(uid, m.vocabSize, m, dataset.sparkSession)
case m: OldDistributedLDAModel =>
- new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
+ new DistributedLDAModel(uid, m.vocabSize, m, dataset.sparkSession, None)
}
copyValues(newModel).setParent(this)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 922670a41b..3fbfce9d48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -230,7 +230,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap
- broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
+ broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
}
val dictBr = broadcastDict.get
val minTf = $(minTF)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 2002d15745..400435d7a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,7 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- dataset.sqlContext.sql(realStatement)
+ dataset.sparkSession.sql(realStatement)
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index a72692960f..c49e263df0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -226,7 +226,7 @@ class Word2VecModel private[ml] (
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
.map(identity) // mapValues doesn't return a serializable map (SI-7005)
- val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+ val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
val d = $(vectorSize)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 728b4d1598..6e4e6a6d85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -387,7 +387,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ALSModel = {
- import dataset.sqlContext.implicits._
+ import dataset.sparkSession.implicits._
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 741724d7a1..da51cb7800 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -186,7 +186,7 @@ final class GBTRegressionModel private[ml](
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 4c4ff278d4..8eaed8b682 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -165,7 +165,7 @@ final class RandomForestRegressionModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index dc2a6f5275..f07374ac0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
@@ -85,12 +85,12 @@ private[libsvm] class LibSVMOutputWriter(
* optionally specify options, for example:
* {{{
* // Scala
- * val df = sqlContext.read.format("libsvm")
+ * val df = spark.read.format("libsvm")
* .option("numFeatures", "780")
* .load("data/mllib/sample_libsvm_data.txt")
*
* // Java
- * DataFrame df = sqlContext.read().format("libsvm")
+ * DataFrame df = spark.read().format("libsvm")
* .option("numFeatures, "780")
* .load("data/mllib/sample_libsvm_data.txt");
* }}}
@@ -123,7 +123,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
Some(
@@ -133,7 +133,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def prepareRead(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Map[String, String] = {
def computeNumFeatures(): Int = {
@@ -146,7 +146,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
throw new IOException("Multiple input paths are not supported for libsvm data.")
}
- val sc = sqlContext.sparkContext
+ val sc = sparkSession.sparkContext
val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
MLUtils.computeNumFeatures(parsed)
}
@@ -159,7 +159,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -176,7 +176,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
@@ -188,9 +188,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
- val broadcastedConf = sqlContext.sparkContext.broadcast(
- new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
- )
+ val broadcastedConf = sparkSession.sparkContext.broadcast(
+ new SerializableConfiguration(
+ new Configuration(sparkSession.sparkContext.hadoopConfiguration)))
(file: PartitionedFile) => {
val points =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index de563d4fad..a41d02cde7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -94,7 +94,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
- val sqlCtx = dataset.sqlContext
+ val sparkSession = dataset.sparkSession
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
@@ -102,8 +102,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val metrics = new Array[Double](epm.length)
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
- val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
- val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+ val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
+ val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 12d6905510..f2b7badbe5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -94,7 +94,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
- val sqlCtx = dataset.sqlContext
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 7dec07ea14..94d1b83ec2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils
/**
@@ -61,8 +61,10 @@ private[util] sealed trait BaseReadWrite {
optionSQLContext.get
}
- /** Returns the [[SparkContext]] underlying [[sqlContext]] */
- protected final def sc: SparkContext = sqlContext.sparkContext
+ protected final def sparkSession: SparkSession = sqlContext.sparkSession
+
+ /** Returns the underlying [[SparkContext]]. */
+ protected final def sc: SparkContext = sparkSession.sparkContext
}
/**
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 27838167fd..0f8648f890 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -672,6 +672,20 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this")
) ++ Seq(
+ // SPARK-14861: Replace internal usages of SQLContext with SparkSession
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.ml.clustering.LocalLDAModel.this"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.ml.clustering.DistributedLDAModel.this"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.ml.clustering.LDAModel.this"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.ml.clustering.LDAModel.sqlContext"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.sql.Dataset.this"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.sql.DataFrameReader.this")
+ ) ++ Seq(
// [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable")
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
index 953169b636..4d5afe2eb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
@@ -35,10 +35,10 @@ trait ContinuousQuery {
def name: String
/**
- * Returns the SQLContext associated with `this` query
+ * Returns the [[SparkSession]] associated with `this`.
* @since 2.0.0
*/
- def sqlContext: SQLContext
+ def sparkSession: SparkSession
/**
* Whether the query is currently active or not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 39d04ed8c2..9e2e2d0bc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -29,16 +29,16 @@ import org.apache.spark.sql.util.ContinuousQueryListener
/**
* :: Experimental ::
* A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active
- * on a [[SQLContext]].
+ * on a [[SparkSession]].
*
* @since 2.0.0
*/
@Experimental
-class ContinuousQueryManager(sqlContext: SQLContext) {
+class ContinuousQueryManager(sparkSession: SparkSession) {
private[sql] val stateStoreCoordinator =
- StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env)
- private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
+ StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env)
+ private val listenerBus = new ContinuousQueryListenerBus(sparkSession.sparkContext.listenerBus)
private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
private val activeQueriesLock = new Object
private val awaitTerminationLock = new Object
@@ -184,7 +184,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
val analyzedPlan = df.queryExecution.analyzed
df.queryExecution.assertAnalyzed()
- if (sqlContext.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) {
+ if (sparkSession.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) {
UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode)
}
@@ -201,7 +201,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
StreamingExecutionRelation(source, output)
}
val query = new StreamExecution(
- sqlContext,
+ sparkSession,
name,
checkpointLocation,
logicalPlan,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index f0e16eefc7..ad00966a91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -155,7 +155,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* @since 1.3.1
*/
def fill(value: Double, cols: Seq[String]): DataFrame = {
- val columnEquals = df.sqlContext.sessionState.analyzer.resolver
+ val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
// Only fill if the column is part of the cols list.
if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
@@ -182,7 +182,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* @since 1.3.1
*/
def fill(value: String, cols: Seq[String]): DataFrame = {
- val columnEquals = df.sqlContext.sessionState.analyzer.resolver
+ val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
// Only fill if the column is part of the cols list.
if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
@@ -355,7 +355,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
case _: String => StringType
}
- val columnEquals = df.sqlContext.sessionState.analyzer.resolver
+ val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
@@ -384,7 +384,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
}
}
- val columnEquals = df.sqlContext.sessionState.analyzer.resolver
+ val columnEquals = df.sparkSession.sessionState.analyzer.resolver
val projections = df.schema.fields.map { f =>
values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
v match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 15f2344df6..b49cda3f15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -36,12 +36,12 @@ import org.apache.spark.sql.types.StructType
/**
* :: Experimental ::
* Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems,
- * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this.
+ * key-value stores, etc) or data streams. Use [[SparkSession.read]] to access this.
*
* @since 1.4.0
*/
@Experimental
-class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
+class DataFrameReader protected[sql](sparkSession: SparkSession) extends Logging {
/**
* Specifies the input data source format.
@@ -125,11 +125,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
def load(): DataFrame = {
val dataSource =
DataSource(
- sqlContext,
+ sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
- Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation()))
+ Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
}
/**
@@ -151,11 +151,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
if (paths.isEmpty) {
- sqlContext.emptyDataFrame
+ sparkSession.emptyDataFrame
} else {
- sqlContext.baseRelationToDataFrame(
+ sparkSession.baseRelationToDataFrame(
DataSource.apply(
- sqlContext,
+ sparkSession,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
@@ -172,11 +172,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
def stream(): DataFrame = {
val dataSource =
DataSource(
- sqlContext,
+ sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
- Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
+ Dataset.ofRows(sparkSession, StreamingRelation(dataSource))
}
/**
@@ -271,8 +271,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
}
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
- val relation = JDBCRelation(url, table, parts, props)(sqlContext)
- sqlContext.baseRelationToDataFrame(relation)
+ val relation = JDBCRelation(url, table, parts, props)(sparkSession)
+ sparkSession.baseRelationToDataFrame(relation)
}
/**
@@ -368,7 +368,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
val parsedOptions: JSONOptions = new JSONOptions(extraOptions.toMap)
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
+ .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val schema = userSpecifiedSchema.getOrElse {
InferSchema.infer(
jsonRDD,
@@ -377,14 +377,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
}
Dataset.ofRows(
- sqlContext,
+ sparkSession,
LogicalRDD(
schema.toAttributes,
JacksonParser.parse(
jsonRDD,
schema,
columnNameOfCorruptRecord,
- parsedOptions))(sqlContext))
+ parsedOptions))(sparkSession))
}
/**
@@ -424,9 +424,9 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 1.4.0
*/
def table(tableName: String): DataFrame = {
- Dataset.ofRows(sqlContext,
- sqlContext.sessionState.catalog.lookupRelation(
- sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
+ Dataset.ofRows(sparkSession,
+ sparkSession.sessionState.catalog.lookupRelation(
+ sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)))
}
/**
@@ -447,14 +447,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
@scala.annotation.varargs
def text(paths: String*): Dataset[String] = {
- format("text").load(paths : _*).as[String](sqlContext.implicits.newStringEncoder)
+ format("text").load(paths : _*).as[String](sparkSession.implicits.newStringEncoder)
}
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
- private var source: String = sqlContext.conf.defaultDataSourceName
+ private var source: String = sparkSession.sessionState.conf.defaultDataSourceName
private var userSpecifiedSchema: Option[StructType] = None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 99d92b9257..c0811f6a4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -237,7 +237,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
def save(): Unit = {
assertNotBucketed()
val dataSource = DataSource(
- df.sqlContext,
+ df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec,
@@ -284,7 +284,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
new Path(userSpecified).toUri.toString
}.orElse {
val checkpointConfig: Option[String] =
- df.sqlContext.conf.getConf(
+ df.sparkSession.getConf(
SQLConf.CHECKPOINT_LOCATION,
None)
@@ -297,7 +297,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
// If offsets have already been created, we trying to resume a query.
val checkpointPath = new Path(checkpointLocation, "offsets")
- val fs = checkpointPath.getFileSystem(df.sqlContext.sessionState.hadoopConf)
+ val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.hadoopConf)
if (fs.exists(checkpointPath)) {
throw new AnalysisException(
s"Unable to resume query written to memory sink. Delete $checkpointPath to start over.")
@@ -306,9 +306,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
val sink = new MemorySink(df.schema)
- val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink))
+ val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink))
resultDf.registerTempTable(queryName)
- val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery(
+ val continuousQuery = df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
checkpointLocation,
df,
@@ -318,16 +318,16 @@ final class DataFrameWriter private[sql](df: DataFrame) {
} else {
val dataSource =
DataSource(
- df.sqlContext,
+ df.sparkSession,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
val checkpointLocation = extraOptions.getOrElse("checkpointLocation", {
- new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString
+ new Path(df.sparkSession.sessionState.conf.checkpointLocation, queryName).toUri.toString
})
- df.sqlContext.sessionState.continuousQueryManager.startQuery(
+ df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
checkpointLocation,
df,
@@ -345,7 +345,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
- insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
+ insertInto(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -363,7 +363,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
Project(inputDataCols ++ inputPartCols, df.logicalPlan)
}.getOrElse(df.logicalPlan)
- df.sqlContext.executePlan(
+ df.sparkSession.executePlan(
InsertIntoTable(
UnresolvedRelation(tableIdent),
partitions.getOrElse(Map.empty[String, Option[String]]),
@@ -413,7 +413,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
private def normalize(columnName: String, columnType: String): String = {
val validColumnNames = df.logicalPlan.output.map(_.name)
- validColumnNames.find(df.sqlContext.sessionState.analyzer.resolver(_, columnName))
+ validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName))
.getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
s"existing columns (${validColumnNames.mkString(", ")})"))
}
@@ -444,11 +444,11 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
- saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
+ saveAsTable(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName))
}
private def saveAsTable(tableIdent: TableIdentifier): Unit = {
- val tableExists = df.sqlContext.sessionState.catalog.tableExists(tableIdent)
+ val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent)
(tableExists, mode) match {
case (true, SaveMode.Ignore) =>
@@ -468,7 +468,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
mode,
extraOptions.toMap,
df.logicalPlan)
- df.sqlContext.executePlan(cmd).toRdd
+ df.sparkSession.executePlan(cmd).toRdd
}
}
@@ -620,7 +620,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
- private var source: String = df.sqlContext.conf.defaultDataSourceName
+ private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName
private var mode: SaveMode = SaveMode.ErrorIfExists
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3c708cbf29..b3064fd531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -46,20 +46,19 @@ import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
private[sql] object Dataset {
- def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = {
- new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]])
+ def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
+ new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
}
- def ofRows(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
- val qe = sqlContext.executePlan(logicalPlan)
+ def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
+ val qe = sparkSession.executePlan(logicalPlan)
qe.assertAnalyzed()
- new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema))
+ new Dataset[Row](sparkSession, logicalPlan, RowEncoder(qe.analyzed.schema))
}
}
@@ -90,8 +89,8 @@ private[sql] object Dataset {
* There are typically two ways to create a Dataset. The most common way is by pointing Spark
* to some files on storage systems, using the `read` function available on a `SparkSession`.
* {{{
- * val people = session.read.parquet("...").as[Person] // Scala
- * Dataset<Person> people = session.read().parquet("...").as(Encoders.bean(Person.class) // Java
+ * val people = spark.read.parquet("...").as[Person] // Scala
+ * Dataset<Person> people = spark.read().parquet("...").as(Encoders.bean(Person.class) // Java
* }}}
*
* Datasets can also be created through transformations available on existing Datasets. For example,
@@ -121,8 +120,8 @@ private[sql] object Dataset {
* A more concrete example in Scala:
* {{{
* // To create Dataset[Row] using SQLContext
- * val people = session.read.parquet("...")
- * val department = session.read.parquet("...")
+ * val people = spark.read.parquet("...")
+ * val department = spark.read.parquet("...")
*
* people.filter("age > 30")
* .join(department, people("deptId") === department("id"))
@@ -133,8 +132,8 @@ private[sql] object Dataset {
* and in Java:
* {{{
* // To create Dataset<Row> using SQLContext
- * Dataset<Row> people = session.read().parquet("...");
- * Dataset<Row> department = session.read().parquet("...");
+ * Dataset<Row> people = spark.read().parquet("...");
+ * Dataset<Row> department = spark.read().parquet("...");
*
* people.filter("age".gt(30))
* .join(department, people.col("deptId").equalTo(department("id")))
@@ -152,8 +151,8 @@ private[sql] object Dataset {
*
* @since 1.6.0
*/
-class Dataset[T] private[sql](
- @transient val sqlContext: SQLContext,
+class Dataset[T] protected[sql](
+ @transient val sparkSession: SparkSession,
@DeveloperApi @transient val queryExecution: QueryExecution,
encoder: Encoder[T])
extends Serializable {
@@ -163,8 +162,12 @@ class Dataset[T] private[sql](
// Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
// you wrap it with `withNewExecutionId` if this actions doesn't call other action.
+ def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
+ this(sparkSession, sparkSession.executePlan(logicalPlan), encoder)
+ }
+
def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
- this(sqlContext, sqlContext.executePlan(logicalPlan), encoder)
+ this(sqlContext.sparkSession, logicalPlan, encoder)
}
@transient protected[sql] val logicalPlan: LogicalPlan = {
@@ -179,9 +182,9 @@ class Dataset[T] private[sql](
// For various commands (like DDL) and queries with side effects, we force query execution
// to happen right away to let these side effects take place eagerly.
case p if hasSideEffects(p) =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession)
case Union(children) if children.forall(hasSideEffects) =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sparkSession)
case _ =>
queryExecution.analyzed
}
@@ -207,8 +210,10 @@ class Dataset[T] private[sql](
private implicit def classTag = unresolvedTEncoder.clsTag
+ def sqlContext: SQLContext = sparkSession.wrapped
+
protected[sql] def resolve(colName: String): NamedExpression = {
- queryExecution.analyzed.resolveQuoted(colName, sqlContext.sessionState.analyzer.resolver)
+ queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
.getOrElse {
throw new AnalysisException(
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
@@ -217,7 +222,7 @@ class Dataset[T] private[sql](
protected[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
- queryExecution.analyzed.resolveQuoted(n.name, sqlContext.sessionState.analyzer.resolver).get
+ queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
}
}
@@ -333,7 +338,7 @@ class Dataset[T] private[sql](
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema))
+ def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema))
/**
* :: Experimental ::
@@ -353,7 +358,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan)
+ def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
/**
* Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
@@ -407,7 +412,7 @@ class Dataset[T] private[sql](
*/
def explain(extended: Boolean): Unit = {
val explain = ExplainCommand(queryExecution.logical, extended = extended)
- sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
+ sparkSession.executePlan(explain).executedPlan.executeCollect().foreach {
// scalastyle:off println
r => println(r.getString(0))
// scalastyle:on println
@@ -631,7 +636,7 @@ class Dataset[T] private[sql](
def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = {
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branch.
- val joined = sqlContext.executePlan(
+ val joined = sparkSession.executePlan(
Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
.analyzed.asInstanceOf[Join]
@@ -695,7 +700,7 @@ class Dataset[T] private[sql](
.queryExecution.analyzed.asInstanceOf[Join]
// If auto self join alias is disabled, return the plan.
- if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
+ if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
return withPlan(plan)
}
@@ -747,7 +752,7 @@ class Dataset[T] private[sql](
val left = this.logicalPlan
val right = other.logicalPlan
- val joined = sqlContext.executePlan(Join(left, right, joinType =
+ val joined = sparkSession.executePlan(Join(left, right, joinType =
JoinType(joinType), Some(condition.expr)))
val leftOutput = joined.analyzed.output.take(left.output.length)
val rightOutput = joined.analyzed.output.takeRight(right.output.length)
@@ -968,7 +973,7 @@ class Dataset[T] private[sql](
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
- Column(sqlContext.sessionState.sqlParser.parseExpression(expr))
+ Column(sparkSession.sessionState.sqlParser.parseExpression(expr))
}: _*)
}
@@ -987,7 +992,7 @@ class Dataset[T] private[sql](
@Experimental
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
new Dataset[U1](
- sqlContext,
+ sparkSession,
Project(
c1.withInputType(
unresolvedTEncoder.deserializer,
@@ -1005,9 +1010,8 @@ class Dataset[T] private[sql](
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
- val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
-
- new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+ val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
+ new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
}
/**
@@ -1091,7 +1095,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def filter(conditionExpr: String): Dataset[T] = {
- filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
+ filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
@@ -1117,7 +1121,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def where(conditionExpr: String): Dataset[T] = {
- filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
+ filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
@@ -1254,7 +1258,7 @@ class Dataset[T] private[sql](
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
val inputPlan = logicalPlan
val withGroupingKey = AppendColumns(func, inputPlan)
- val executed = sqlContext.executePlan(withGroupingKey)
+ val executed = sparkSession.executePlan(withGroupingKey)
new KeyValueGroupedDataset(
encoderFor[K],
@@ -1507,7 +1511,7 @@ class Dataset[T] private[sql](
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset[T](
- sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+ sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
}.toArray
}
@@ -1630,7 +1634,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def withColumn(colName: String, col: Column): DataFrame = {
- val resolver = sqlContext.sessionState.analyzer.resolver
+ val resolver = sparkSession.sessionState.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldReplace = output.exists(f => resolver(f.name, colName))
if (shouldReplace) {
@@ -1651,7 +1655,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] by adding a column with metadata.
*/
private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
- val resolver = sqlContext.sessionState.analyzer.resolver
+ val resolver = sparkSession.sessionState.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldReplace = output.exists(f => resolver(f.name, colName))
if (shouldReplace) {
@@ -1676,7 +1680,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def withColumnRenamed(existingName: String, newName: String): DataFrame = {
- val resolver = sqlContext.sessionState.analyzer.resolver
+ val resolver = sparkSession.sessionState.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldRename = output.exists(f => resolver(f.name, existingName))
if (shouldRename) {
@@ -1713,7 +1717,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def drop(colNames: String*): DataFrame = {
- val resolver = sqlContext.sessionState.analyzer.resolver
+ val resolver = sparkSession.sessionState.analyzer.resolver
val remainingCols =
schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
if (remainingCols.size == this.schema.size) {
@@ -1736,7 +1740,7 @@ class Dataset[T] private[sql](
val expression = col match {
case Column(u: UnresolvedAttribute) =>
queryExecution.analyzed.resolveQuoted(
- u.name, sqlContext.sessionState.analyzer.resolver).getOrElse(u)
+ u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
case Column(expr: Expression) => expr
}
val attrs = this.logicalPlan.output
@@ -1957,7 +1961,7 @@ class Dataset[T] private[sql](
@Experimental
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
new Dataset[U](
- sqlContext,
+ sparkSession,
MapPartitions[T, U](func, logicalPlan),
implicitly[Encoder[U]])
}
@@ -2203,7 +2207,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def persist(): this.type = {
- sqlContext.cacheManager.cacheQuery(this)
+ sparkSession.cacheManager.cacheQuery(this)
this
}
@@ -2225,7 +2229,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+ sparkSession.cacheManager.cacheQuery(this, None, newLevel)
this
}
@@ -2238,7 +2242,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def unpersist(blocking: Boolean): this.type = {
- sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+ sparkSession.cacheManager.tryUncacheQuery(this, blocking)
this
}
@@ -2259,7 +2263,7 @@ class Dataset[T] private[sql](
lazy val rdd: RDD[T] = {
val objectType = unresolvedTEncoder.deserializer.dataType
val deserialized = CatalystSerde.deserialize[T](logicalPlan)
- sqlContext.executePlan(deserialized).toRdd.mapPartitions { rows =>
+ sparkSession.executePlan(deserialized).toRdd.mapPartitions { rows =>
rows.map(_.get(0, objectType).asInstanceOf[T])
}
}
@@ -2286,7 +2290,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def registerTempTable(tableName: String): Unit = {
- sqlContext.registerDataFrameAsTable(toDF(), tableName)
+ sparkSession.registerDataFrameAsTable(toDF(), tableName)
}
/**
@@ -2327,8 +2331,8 @@ class Dataset[T] private[sql](
}
}
}
- import sqlContext.implicits.newStringEncoder
- sqlContext.createDataset(rdd)
+ import sparkSession.implicits.newStringEncoder
+ sparkSession.createDataset(rdd)
}
/**
@@ -2383,7 +2387,7 @@ class Dataset[T] private[sql](
* an execution.
*/
private[sql] def withNewExecutionId[U](body: => U): U = {
- SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
+ SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body)
}
/**
@@ -2398,11 +2402,11 @@ class Dataset[T] private[sql](
val start = System.nanoTime()
val result = action(df)
val end = System.nanoTime()
- sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start)
+ sparkSession.listenerManager.onSuccess(name, df.queryExecution, end - start)
result
} catch {
case e: Exception =>
- sqlContext.listenerManager.onFailure(name, df.queryExecution, e)
+ sparkSession.listenerManager.onFailure(name, df.queryExecution, e)
throw e
}
}
@@ -2415,11 +2419,11 @@ class Dataset[T] private[sql](
val start = System.nanoTime()
val result = action(ds)
val end = System.nanoTime()
- sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start)
+ sparkSession.listenerManager.onSuccess(name, ds.queryExecution, end - start)
result
} catch {
case e: Exception =>
- sqlContext.listenerManager.onFailure(name, ds.queryExecution, e)
+ sparkSession.listenerManager.onFailure(name, ds.queryExecution, e)
throw e
}
}
@@ -2440,11 +2444,11 @@ class Dataset[T] private[sql](
/** A convenient function to wrap a logical plan and produce a DataFrame. */
@inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
- Dataset.ofRows(sqlContext, logicalPlan)
+ Dataset.ofRows(sparkSession, logicalPlan)
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
@inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
- Dataset(sqlContext, logicalPlan)
+ Dataset(sparkSession, logicalPlan)
}
}
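
The Dataset.scala hunks above are mechanical: every internal `sqlContext.*` call now routes through `sparkSession` (or `sparkSession.sessionState`). As a sanity check, here is a minimal user-level sketch, assuming an existing Spark 2.0 `SparkSession` named `spark` and made-up column names; the public Dataset surface touched by these hunks (filter, selectExpr, withColumnRenamed, persist) behaves exactly as before:

val df = spark.range(0, 100).toDF("id")
val even = df.filter("id % 2 = 0")                      // condition parsed by sparkSession.sessionState.sqlParser
val renamed = even.withColumnRenamed("id", "even_id")   // resolver now comes from sparkSession.sessionState.analyzer
renamed.selectExpr("even_id * 2 AS doubled").persist()  // cached via sparkSession.cacheManager
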
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 05e13e66d1..3a5ea19b8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -55,7 +55,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
private def logicalPlan = queryExecution.analyzed
- private def sqlContext = queryExecution.sqlContext
+ private def sparkSession = queryExecution.sparkSession
/**
* Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
@@ -79,7 +79,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
*/
def keys: Dataset[K] = {
Dataset[K](
- sqlContext,
+ sparkSession,
Distinct(
Project(groupingAttributes, logicalPlan)))
}
@@ -104,7 +104,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
*/
def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
Dataset[U](
- sqlContext,
+ sparkSession,
MapGroups(
f,
groupingAttributes,
@@ -217,10 +217,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
Alias(CreateStruct(groupingAttributes), "key")()
}
val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
- val execution = new QueryExecution(sqlContext, aggregate)
+ val execution = new QueryExecution(sparkSession, aggregate)
new Dataset(
- sqlContext,
+ sparkSession,
execution,
ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
}
@@ -289,7 +289,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit val uEncoder = other.unresolvedVEncoder
Dataset[R](
- sqlContext,
+ sparkSession,
CoGroup(
f,
this.groupingAttributes,
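
KeyValueGroupedDataset now resolves its private `sparkSession` from the query execution and builds result Datasets from it. A small sketch, again assuming an existing session `spark` (the data and names are illustrative):

import spark.implicits._
val words = Seq("spark", "sql", "session", "context").toDS()
val byLength = words.groupByKey(_.length)                 // KeyValueGroupedDataset[Int, String]
byLength.keys.show()                                      // built as Dataset[K](sparkSession, Distinct(Project(...)))
byLength.mapGroups((len, ws) => s"$len -> ${ws.mkString(",")}").show()
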
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 0ffb136c24..7ee9732fa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -42,7 +42,7 @@ class RelationalGroupedDataset protected[sql](
groupType: RelationalGroupedDataset.GroupType) {
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
- val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+ val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) {
groupingExprs ++ aggExprs
} else {
aggExprs
@@ -53,17 +53,17 @@ class RelationalGroupedDataset protected[sql](
groupType match {
case RelationalGroupedDataset.GroupByType =>
Dataset.ofRows(
- df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
+ df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case RelationalGroupedDataset.RollupType =>
Dataset.ofRows(
- df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
+ df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
case RelationalGroupedDataset.CubeType =>
Dataset.ofRows(
- df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
+ df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
case RelationalGroupedDataset.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
Dataset.ofRows(
- df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
+ df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
}
}
@@ -302,7 +302,7 @@ class RelationalGroupedDataset protected[sql](
*/
def pivot(pivotColumn: String): RelationalGroupedDataset = {
// This is to prevent unintended OOM errors when the number of distinct values is large
- val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ val maxValues = df.sparkSession.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
// Get the distinct values of the column and sort them so it's consistent
val values = df.select(pivotColumn)
.distinct()
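
RelationalGroupedDataset likewise builds its aggregate DataFrames from `df.sparkSession`, and the pivot value cap is now read from the session's conf (SQLConf.DATAFRAME_PIVOT_MAX_VALUES). A sketch with hypothetical data, assuming an existing session `spark`:

import spark.implicits._
val sales = Seq((2015, "dotNET", 5000), (2015, "Java", 20000), (2016, "dotNET", 48000))
  .toDF("year", "course", "earnings")
sales.groupBy("year").pivot("course").sum("earnings").show()   // distinct pivot values capped by the session conf
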
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dde139608a..47c043a00d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -58,7 +58,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager
* @since 1.0.0
*/
class SQLContext private[sql](
- @transient private val sparkSession: SparkSession,
+ val sparkSession: SparkSession,
val isRootContext: Boolean)
extends Logging with Serializable {
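
The only SQLContext change here is that its `sparkSession` field drops `@transient private` and becomes a public val, so legacy code that still holds a SQLContext can reach the underlying session directly. A two-line sketch, assuming an existing `sqlContext`:

val session = sqlContext.sparkSession   // public as of this patch
session.sql("SELECT 1 AS ok").show()
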
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 00256bd0be..a0f0bd3f59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -52,7 +52,8 @@ import org.apache.spark.util.Utils
*/
class SparkSession private(
@transient val sparkContext: SparkContext,
- @transient private val existingSharedState: Option[SharedState]) { self =>
+ @transient private val existingSharedState: Option[SharedState])
+ extends Serializable { self =>
def this(sc: SparkContext) {
this(sc, None)
@@ -81,9 +82,9 @@ class SparkSession private(
*/
@transient
protected[sql] lazy val sessionState: SessionState = {
- SparkSession.reflect[SessionState, SQLContext](
+ SparkSession.reflect[SessionState, SparkSession](
SparkSession.sessionStateClassName(sparkContext.conf),
- new SQLContext(self, isRootContext = false))
+ self)
}
/**
@@ -358,7 +359,7 @@ class SparkSession private(
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
- Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRDD)(wrapped))
+ Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self))
}
/**
@@ -373,7 +374,7 @@ class SparkSession private(
SQLContext.setActive(wrapped)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
- Dataset.ofRows(wrapped, LocalRelation.fromProduct(attributeSeq, data))
+ Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
}
/**
@@ -438,7 +439,7 @@ class SparkSession private(
*/
@DeveloperApi
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
- Dataset.ofRows(wrapped, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
+ Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
}
/**
@@ -458,7 +459,7 @@ class SparkSession private(
val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
}
- Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRdd)(wrapped))
+ Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
}
/**
@@ -486,7 +487,7 @@ class SparkSession private(
val attrSeq = getSchema(beanClass)
val beanInfo = Introspector.getBeanInfo(beanClass)
val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
- Dataset.ofRows(wrapped, LocalRelation(attrSeq, rows.toSeq))
+ Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
}
/**
@@ -496,7 +497,7 @@ class SparkSession private(
* @since 2.0.0
*/
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
- Dataset.ofRows(wrapped, LogicalRelation(baseRelation))
+ Dataset.ofRows(self, LogicalRelation(baseRelation))
}
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
@@ -504,15 +505,15 @@ class SparkSession private(
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
- Dataset[T](wrapped, plan)
+ Dataset[T](self, plan)
}
def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d))
- val plan = LogicalRDD(attributes, encoded)(wrapped)
- Dataset[T](wrapped, plan)
+ val plan = LogicalRDD(attributes, encoded)(self)
+ Dataset[T](self, plan)
}
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
@@ -567,7 +568,7 @@ class SparkSession private(
*/
@Experimental
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
- new Dataset(wrapped, Range(start, end, step, numPartitions), Encoders.LONG)
+ new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG)
}
/**
@@ -579,8 +580,8 @@ class SparkSession private(
schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped)
- Dataset.ofRows(wrapped, logicalPlan)
+ val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
+ Dataset.ofRows(self, logicalPlan)
}
/**
@@ -599,8 +600,8 @@ class SparkSession private(
} else {
rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
}
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped)
- Dataset.ofRows(wrapped, logicalPlan)
+ val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
+ Dataset.ofRows(self, logicalPlan)
}
@@ -749,7 +750,7 @@ class SparkSession private(
}
private def table(tableIdent: TableIdentifier): DataFrame = {
- Dataset.ofRows(wrapped, sessionState.catalog.lookupRelation(tableIdent))
+ Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent))
}
/**
@@ -761,7 +762,7 @@ class SparkSession private(
* @since 2.0.0
*/
def tables(): DataFrame = {
- Dataset.ofRows(wrapped, ShowTablesCommand(None, None))
+ Dataset.ofRows(self, ShowTablesCommand(None, None))
}
/**
@@ -773,7 +774,7 @@ class SparkSession private(
* @since 2.0.0
*/
def tables(databaseName: String): DataFrame = {
- Dataset.ofRows(wrapped, ShowTablesCommand(Some(databaseName), None))
+ Dataset.ofRows(self, ShowTablesCommand(Some(databaseName), None))
}
/**
@@ -820,7 +821,7 @@ class SparkSession private(
* @since 2.0.0
*/
def sql(sqlText: String): DataFrame = {
- Dataset.ofRows(wrapped, parseSql(sqlText))
+ Dataset.ofRows(self, parseSql(sqlText))
}
/**
@@ -835,7 +836,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- def read: DataFrameReader = new DataFrameReader(wrapped)
+ def read: DataFrameReader = new DataFrameReader(self)
// scalastyle:off
@@ -906,7 +907,7 @@ class SparkSession private(
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
- Dataset.ofRows(wrapped, LogicalRDD(schema.toAttributes, rowRdd)(wrapped))
+ Dataset.ofRows(self, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
/**
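
All of the SparkSession factory methods above now pass `self` into `Dataset.ofRows`/`LogicalRDD` instead of the wrapped SQLContext. A sketch using only entry points visible in these hunks (the SparkContext constructor, `range`, `createDataset`, `sql`); in the released 2.0 API the session would more typically come from `SparkSession.builder`, so treat the construction below as illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("session-demo"))
val spark = new SparkSession(sc)               // `def this(sc: SparkContext)` shown above
import spark.implicits._

spark.range(0, 10, 2, 2).show()                // Dataset[java.lang.Long] with 2 partitions
spark.createDataset(Seq(1, 2, 3)).show()       // Seq-based createDataset from the hunk above
spark.sql("SELECT 'spark' AS engine").show()
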
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 124ec09efd..f601138a9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -87,15 +87,15 @@ private[sql] class CacheManager extends Logging {
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
- val sqlContext = query.sqlContext
+ val sparkSession = query.sparkSession
cachedData +=
CachedData(
planToCache,
InMemoryRelation(
- sqlContext.conf.useCompression,
- sqlContext.conf.columnBatchSize,
+ sparkSession.sessionState.conf.useCompression,
+ sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
- sqlContext.executePlan(planToCache).executedPlan,
+ sparkSession.executePlan(planToCache).executedPlan,
tableName))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 7afdf75f38..520ceaaaea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
@@ -75,15 +75,15 @@ object RDDConversions {
/** Logical plan node for scanning data from an RDD. */
private[sql] case class LogicalRDD(
output: Seq[Attribute],
- rdd: RDD[InternalRow])(sqlContext: SQLContext)
+ rdd: RDD[InternalRow])(session: SparkSession)
extends LogicalPlan with MultiInstanceRelation {
override def children: Seq[LogicalPlan] = Nil
- override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil
+ override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def newInstance(): LogicalRDD.this.type =
- LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
+ LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
override def sameResult(plan: LogicalPlan): Boolean = plan match {
case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
@@ -95,7 +95,7 @@ private[sql] case class LogicalRDD(
@transient override lazy val statistics: Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
- sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes)
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)
}
@@ -329,7 +329,8 @@ private[sql] object DataSourceScanExec {
val outputPartitioning = {
val bucketSpec = relation match {
// TODO: this should be closer to bucket planning.
- case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec
+ case r: HadoopFsRelation
+ if r.sparkSession.sessionState.conf.bucketingEnabled => r.bucketSpec
case _ => None
}
@@ -349,7 +350,7 @@ private[sql] object DataSourceScanExec {
relation match {
case r: HadoopFsRelation
- if r.fileFormat.supportBatch(r.sqlContext, StructType.fromAttributes(output)) =>
+ if r.fileFormat.supportBatch(r.sparkSession, StructType.fromAttributes(output)) =>
BatchedDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
case _ =>
RowDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index bb83676b7d..d3d83b0218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets
import java.sql.Timestamp
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
@@ -39,39 +39,41 @@ import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampT
* While this is not a public class, we should avoid changing the function names for the sake of
* changing them, because a lot of developers use the feature for debugging.
*/
-class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
+class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
// TODO: Move the planner an optimizer into here from SessionState.
- protected def planner = sqlContext.sessionState.planner
-
- def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
- case e: AnalysisException =>
- val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
- ae.setStackTrace(e.getStackTrace)
- throw ae
+ protected def planner = sparkSession.sessionState.planner
+
+ def assertAnalyzed(): Unit = {
+ try sparkSession.sessionState.analyzer.checkAnalysis(analyzed) catch {
+ case e: AnalysisException =>
+ val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
+ ae.setStackTrace(e.getStackTrace)
+ throw ae
+ }
}
def assertSupported(): Unit = {
- if (sqlContext.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) {
+ if (sparkSession.sessionState.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) {
UnsupportedOperationChecker.checkForBatch(analyzed)
}
}
lazy val analyzed: LogicalPlan = {
- SQLContext.setActive(sqlContext)
- sqlContext.sessionState.analyzer.execute(logical)
+ SQLContext.setActive(sparkSession.wrapped)
+ sparkSession.sessionState.analyzer.execute(logical)
}
lazy val withCachedData: LogicalPlan = {
assertAnalyzed()
assertSupported()
- sqlContext.cacheManager.useCachedData(analyzed)
+ sparkSession.cacheManager.useCachedData(analyzed)
}
- lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData)
+ lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
lazy val sparkPlan: SparkPlan = {
- SQLContext.setActive(sqlContext)
+ SQLContext.setActive(sparkSession.wrapped)
planner.plan(ReturnAnswer(optimizedPlan)).next()
}
@@ -93,10 +95,10 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
- PlanSubqueries(sqlContext),
- EnsureRequirements(sqlContext.conf),
- CollapseCodegenStages(sqlContext.conf),
- ReuseExchange(sqlContext.conf))
+ PlanSubqueries(sparkSession),
+ EnsureRequirements(sparkSession.sessionState.conf),
+ CollapseCodegenStages(sparkSession.sessionState.conf),
+ ReuseExchange(sparkSession.sessionState.conf))
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -110,7 +112,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
case ExecutedCommandExec(desc: DescribeTableCommand) =>
// If it is a describe command for a Hive table, we want to have the output format
// be similar with Hive.
- desc.run(sqlContext).map {
+ desc.run(sparkSession).map {
case Row(name: String, dataType: String, comment) =>
Seq(name, dataType,
Option(comment.asInstanceOf[String]).getOrElse(""))
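
QueryExecution now carries the SparkSession directly and pulls the analyzer, optimizer and planner from its `sessionState`. Below is a toy, self-contained sketch of the same lazy staged-pipeline shape; the `Toy*` names are invented and are not Spark classes:

final case class ToyPlan(steps: List[String])

final class ToySessionState {
  def analyze(p: ToyPlan): ToyPlan  = p.copy(steps = p.steps :+ "analyzed")
  def optimize(p: ToyPlan): ToyPlan = p.copy(steps = p.steps :+ "optimized")
}

final class ToySession { val sessionState = new ToySessionState }

final class ToyQueryExecution(val session: ToySession, val logical: ToyPlan) {
  // Each stage is lazy and feeds the next, mirroring analyzed -> withCachedData -> optimizedPlan.
  lazy val analyzed: ToyPlan      = session.sessionState.analyze(logical)
  lazy val optimizedPlan: ToyPlan = session.sessionState.optimize(analyzed)
}

val qe = new ToyQueryExecution(new ToySession, ToyPlan(List("parsed")))
println(qe.optimizedPlan.steps.mkString(" -> "))   // parsed -> analyzed -> optimized
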
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 0a11b16d0e..397d66b311 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark.SparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
SparkListenerSQLExecutionStart}
import org.apache.spark.util.Utils
@@ -38,21 +38,22 @@ private[sql] object SQLExecution {
* we can connect them with an execution.
*/
def withNewExecutionId[T](
- sqlContext: SQLContext, queryExecution: QueryExecution)(body: => T): T = {
- val sc = sqlContext.sparkContext
+ sparkSession: SparkSession,
+ queryExecution: QueryExecution)(body: => T): T = {
+ val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
if (oldExecutionId == null) {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
val r = try {
val callSite = Utils.getCallSite()
- sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+ sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
try {
body
} finally {
- sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+ sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
executionId, System.currentTimeMillis()))
}
} finally {
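
SQLExecution.withNewExecutionId now takes the session only to reach its SparkContext; the ID bookkeeping is unchanged. A toy, self-contained sketch of that wrap-the-body shape (invented names, println in place of listener-bus events):

import java.util.concurrent.atomic.AtomicLong

object ToyExecution {
  private val nextId = new AtomicLong(0L)
  def withNewExecutionId[T](body: => T): T = {
    val id = nextId.getAndIncrement()
    println(s"execution $id started")            // stand-in for SparkListenerSQLExecutionStart
    try body
    finally println(s"execution $id finished")   // stand-in for SparkListenerSQLExecutionEnd
  }
}

val answer = ToyExecution.withNewExecutionId { 1 + 1 }
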
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index e28e456662..861ff3cd15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -18,12 +18,10 @@
package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
-import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
-import scala.util.control.NonFatal
import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala
index e6c5351106..b6f7808398 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
@@ -35,8 +35,8 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable}
*/
case class AnalyzeTable(tableName: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val sessionState = sqlContext.sessionState
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val sessionState = sparkSession.sessionState
val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent))
@@ -77,7 +77,7 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand {
catalogTable.storage.locationUri.map { p =>
val path = new Path(p)
try {
- val fs = path.getFileSystem(sqlContext.sessionState.hadoopConf)
+ val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf)
calculateTableSize(fs, path)
} catch {
case NonFatal(e) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala
index 39e441f1c3..bf66ea46fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/HiveNativeCommand.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.StringType
@@ -29,7 +29,7 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand {
override def output: Seq[AttributeReference] =
Seq(AttributeReference("result", StringType, nullable = false)())
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.runNativeSql(sql).map(Row(_))
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sessionState.runNativeSql(sql).map(Row(_))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index 4daf9e916a..952a0d676f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
import java.util.NoSuchElementException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -43,10 +43,10 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
schema.toAttributes
}
- private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match {
+ private val (_output, runFunc): (Seq[Attribute], SparkSession => Seq[Row]) = kv match {
// Configures the deprecated "mapred.reduce.tasks" property.
case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.")
@@ -56,14 +56,14 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
"determining the number of reducers is not supported."
throw new IllegalArgumentException(msg)
} else {
- sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value)
+ sparkSession.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value)
Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value))
}
}
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " +
s"External sort will continue to be used.")
@@ -72,7 +72,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " +
s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " +
@@ -82,7 +82,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.TUNGSTEN_ENABLED, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.TUNGSTEN_ENABLED} is deprecated and " +
s"will be ignored. Tungsten will continue to be used.")
@@ -91,7 +91,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.CODEGEN_ENABLED, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.CODEGEN_ENABLED} is deprecated and " +
s"will be ignored. Codegen will continue to be used.")
@@ -100,7 +100,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.UNSAFE_ENABLED, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.UNSAFE_ENABLED} is deprecated and " +
s"will be ignored. Unsafe mode will continue to be used.")
@@ -109,7 +109,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.SORTMERGE_JOIN} is deprecated and " +
s"will be ignored. Sort merge join will continue to be used.")
@@ -118,7 +118,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
(keyValueOutput, runFunc)
case Some((SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED} is " +
s"deprecated and will be ignored. Vectorized parquet reader will be used instead.")
@@ -128,25 +128,25 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
// Configures a single property.
case Some((key, Some(value))) =>
- val runFunc = (sqlContext: SQLContext) => {
- sqlContext.setConf(key, value)
+ val runFunc = (sparkSession: SparkSession) => {
+ sparkSession.setConf(key, value)
Seq(Row(key, value))
}
(keyValueOutput, runFunc)
// (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.)
- // Queries all key-value pairs that are set in the SQLConf of the sqlContext.
+ // Queries all key-value pairs that are set in the SQLConf of the sparkSession.
case None =>
- val runFunc = (sqlContext: SQLContext) => {
- sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq
+ val runFunc = (sparkSession: SparkSession) => {
+ sparkSession.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq
}
(keyValueOutput, runFunc)
// Queries all properties along with their default values and docs that are defined in the
- // SQLConf of the sqlContext.
+ // SQLConf of the sparkSession.
case Some(("-v", None)) =>
- val runFunc = (sqlContext: SQLContext) => {
- sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) =>
+ val runFunc = (sparkSession: SparkSession) => {
+ sparkSession.sessionState.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) =>
Row(key, defaultValue, doc)
}
}
@@ -158,19 +158,21 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
// Queries the deprecated "mapred.reduce.tasks" property.
case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
logWarning(
s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.")
- Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString))
+ Seq(Row(
+ SQLConf.SHUFFLE_PARTITIONS.key,
+ sparkSession.sessionState.conf.numShufflePartitions.toString))
}
(keyValueOutput, runFunc)
// Queries a single property.
case Some((key, None)) =>
- val runFunc = (sqlContext: SQLContext) => {
+ val runFunc = (sparkSession: SparkSession) => {
val value =
- try sqlContext.getConf(key) catch {
+ try sparkSession.getConf(key) catch {
case _: NoSuchElementException => "<undefined>"
}
Seq(Row(key, value))
@@ -180,6 +182,6 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
override val output: Seq[Attribute] = _output
- override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext)
+ override def run(sparkSession: SparkSession): Seq[Row] = runFunc(sparkSession)
}
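
Each `runFunc` above now takes a SparkSession, but the SQL front end is untouched. Assuming the `spark` session from the earlier sketch:

spark.sql("SET spark.sql.shuffle.partitions=8")                      // routed to sparkSession.setConf
spark.sql("SET spark.sql.shuffle.partitions").show(truncate = false) // query a single property
spark.sql("SET -v").show(5, truncate = false)                        // all defined confs with defaults and docs
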
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 5be5d0c2b0..c283bd61d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{Dataset, Row, SQLContext}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -28,15 +28,15 @@ case class CacheTableCommand(
isLazy: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
plan.foreach { logicalPlan =>
- sqlContext.registerDataFrameAsTable(Dataset.ofRows(sqlContext, logicalPlan), tableName)
+ sparkSession.registerDataFrameAsTable(Dataset.ofRows(sparkSession, logicalPlan), tableName)
}
- sqlContext.cacheTable(tableName)
+ sparkSession.cacheTable(tableName)
if (!isLazy) {
// Performs eager caching
- sqlContext.table(tableName).count()
+ sparkSession.table(tableName).count()
}
Seq.empty[Row]
@@ -48,8 +48,8 @@ case class CacheTableCommand(
case class UncacheTableCommand(tableName: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.table(tableName).unpersist(blocking = false)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.table(tableName).unpersist(blocking = false)
Seq.empty[Row]
}
@@ -61,8 +61,8 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand {
*/
case object ClearCacheCommand extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.clearCache()
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.clearCache()
Seq.empty[Row]
}
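
The cache commands resolve tables through the session as well. Equivalent SQL, again assuming `spark` and a throwaway table name:

spark.range(0, 10).toDF("id").registerTempTable("nums")   // registerTempTable now delegates to the session
spark.sql("CACHE TABLE nums")                             // CacheTableCommand -> sparkSession.cacheTable
spark.sql("UNCACHE TABLE nums")                           // UncacheTableCommand
spark.sql("CLEAR CACHE")                                  // ClearCacheCommand
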
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 0fd7fa92a3..7bb59b7803 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.execution.command
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, Row, SQLContext}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
private[sql] trait RunnableCommand extends LogicalPlan with logical.Command {
override def output: Seq[Attribute] = Seq.empty
override def children: Seq[LogicalPlan] = Seq.empty
- def run(sqlContext: SQLContext): Seq[Row]
+ def run(sparkSession: SparkSession): Seq[Row]
}
/**
@@ -54,7 +54,7 @@ private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkP
*/
protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
- cmd.run(sqlContext).map(converter(_).asInstanceOf[InternalRow])
+ cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
}
override def output: Seq[Attribute] = cmd.output
@@ -97,8 +97,8 @@ case class ExplainCommand(
extends RunnableCommand {
// Run through the optimizer to generate the physical plan.
- override def run(sqlContext: SQLContext): Seq[Row] = try {
- val queryExecution = sqlContext.executePlan(logicalPlan)
+ override def run(sparkSession: SparkSession): Seq[Row] = try {
+ val queryExecution = sparkSession.executePlan(logicalPlan)
val outputString =
if (codegen) {
codegenString(queryExecution.executedPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 0ef1d1d688..31900b4993 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -55,7 +55,7 @@ case class CreateDataSourceTableCommand(
managedIfNoPath: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Since we are saving metadata to metastore, we need to check if metastore supports
// the table name and database name we have for this query. MetaStoreUtils.validateName
// is the method used by Hive to check if a table name or a database name is valid for
@@ -72,7 +72,7 @@ case class CreateDataSourceTableCommand(
}
val tableName = tableIdent.unquotedString
- val sessionState = sqlContext.sessionState
+ val sessionState = sparkSession.sessionState
if (sessionState.catalog.tableExists(tableIdent)) {
if (ignoreIfExists) {
@@ -93,14 +93,14 @@ case class CreateDataSourceTableCommand(
// Create the relation to validate the arguments before writing the metadata to the metastore.
DataSource(
- sqlContext = sqlContext,
+ sparkSession = sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
className = provider,
bucketSpec = None,
options = optionsWithPath).resolveRelation()
CreateDataSourceTableUtils.createDataSourceTable(
- sqlContext = sqlContext,
+ sparkSession = sparkSession,
tableIdent = tableIdent,
userSpecifiedSchema = userSpecifiedSchema,
partitionColumns = Array.empty[String],
@@ -136,7 +136,7 @@ case class CreateDataSourceTableAsSelectCommand(
query: LogicalPlan)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Since we are saving metadata to metastore, we need to check if metastore supports
// the table name and database name we have for this query. MetaStoreUtils.validateName
// is the method used by Hive to check if a table name or a database name is valid for
@@ -153,7 +153,7 @@ case class CreateDataSourceTableAsSelectCommand(
}
val tableName = tableIdent.unquotedString
- val sessionState = sqlContext.sessionState
+ val sessionState = sparkSession.sessionState
var createMetastoreTable = false
var isExternal = true
val optionsWithPath =
@@ -165,7 +165,7 @@ case class CreateDataSourceTableAsSelectCommand(
}
var existingSchema = None: Option[StructType]
- if (sqlContext.sessionState.catalog.tableExists(tableIdent)) {
+ if (sparkSession.sessionState.catalog.tableExists(tableIdent)) {
// Check if we need to throw an exception or just return.
mode match {
case SaveMode.ErrorIfExists =>
@@ -180,7 +180,7 @@ case class CreateDataSourceTableAsSelectCommand(
case SaveMode.Append =>
// Check if the specified data source matches the data source of the existing table.
val dataSource = DataSource(
- sqlContext = sqlContext,
+ sparkSession = sparkSession,
userSpecifiedSchema = Some(query.schema.asNullable),
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
@@ -197,7 +197,7 @@ case class CreateDataSourceTableAsSelectCommand(
throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
}
case SaveMode.Overwrite =>
- sqlContext.sql(s"DROP TABLE IF EXISTS $tableName")
+ sparkSession.sql(s"DROP TABLE IF EXISTS $tableName")
// Need to create the table again.
createMetastoreTable = true
}
@@ -206,7 +206,7 @@ case class CreateDataSourceTableAsSelectCommand(
createMetastoreTable = true
}
- val data = Dataset.ofRows(sqlContext, query)
+ val data = Dataset.ofRows(sparkSession, query)
val df = existingSchema match {
// If we are inserting into an existing table, just use the existing schema.
case Some(s) => data.selectExpr(s.fieldNames: _*)
@@ -215,7 +215,7 @@ case class CreateDataSourceTableAsSelectCommand(
// Create the relation based on the data of df.
val dataSource = DataSource(
- sqlContext,
+ sparkSession,
className = provider,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
@@ -228,7 +228,7 @@ case class CreateDataSourceTableAsSelectCommand(
// the schema of df). It is important since the nullability may be changed by the relation
// provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
CreateDataSourceTableUtils.createDataSourceTable(
- sqlContext = sqlContext,
+ sparkSession = sparkSession,
tableIdent = tableIdent,
userSpecifiedSchema = Some(result.schema),
partitionColumns = partitionColumns,
@@ -260,7 +260,7 @@ object CreateDataSourceTableUtils extends Logging {
}
def createDataSourceTable(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
partitionColumns: Array[String],
@@ -275,7 +275,7 @@ object CreateDataSourceTableUtils extends Logging {
// stored into a single metastore SerDe property. In this case, we split the JSON string and
// store each part as a separate SerDe property.
userSpecifiedSchema.foreach { schema =>
- val threshold = sqlContext.sessionState.conf.schemaStringLengthThreshold
+ val threshold = sparkSession.sessionState.conf.schemaStringLengthThreshold
val schemaJsonString = schema.json
// Split the JSON string.
val parts = schemaJsonString.grouped(threshold).toSeq
@@ -329,10 +329,10 @@ object CreateDataSourceTableUtils extends Logging {
CatalogTableType.MANAGED_TABLE
}
- val maybeSerDe = HiveSerDe.sourceToSerDe(provider, sqlContext.sessionState.conf)
+ val maybeSerDe = HiveSerDe.sourceToSerDe(provider, sparkSession.sessionState.conf)
val dataSource =
DataSource(
- sqlContext,
+ sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
@@ -432,7 +432,7 @@ object CreateDataSourceTableUtils extends Logging {
// specific way.
try {
logInfo(message)
- sqlContext.sessionState.catalog.createTable(table, ignoreIfExists = false)
+ sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
} catch {
case NonFatal(e) =>
val warningMessage =
@@ -440,13 +440,13 @@ object CreateDataSourceTableUtils extends Logging {
s"it into Hive metastore in Spark SQL specific format."
logWarning(warningMessage, e)
val table = newSparkSQLSpecificMetastoreTable()
- sqlContext.sessionState.catalog.createTable(table, ignoreIfExists = false)
+ sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
}
case (None, message) =>
logWarning(message)
val table = newSparkSQLSpecificMetastoreTable()
- sqlContext.sessionState.catalog.createTable(table, ignoreIfExists = false)
+ sparkSession.sessionState.catalog.createTable(table, ignoreIfExists = false)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala
index 33cc10d53a..cefe0f6e62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.StringType
@@ -38,10 +38,10 @@ case class ShowDatabasesCommand(databasePattern: Option[String]) extends Runnabl
AttributeReference("result", StringType, nullable = false)() :: Nil
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
val databases =
- databasePattern.map(catalog.listDatabases(_)).getOrElse(catalog.listDatabases())
+ databasePattern.map(catalog.listDatabases).getOrElse(catalog.listDatabases())
databases.map { d => Row(d) }
}
}
@@ -55,8 +55,8 @@ case class ShowDatabasesCommand(databasePattern: Option[String]) extends Runnabl
*/
case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.catalog.setCurrentDatabase(databaseName)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sessionState.catalog.setCurrentDatabase(databaseName)
Seq.empty[Row]
}
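
ShowDatabasesCommand and SetDatabaseCommand read the catalog from `sparkSession.sessionState`; from SQL nothing changes:

spark.sql("SHOW DATABASES").show()                 // ShowDatabasesCommand
spark.sql("SHOW DATABASES LIKE 'def*'").show()     // pattern form, catalog.listDatabases(pattern)
spark.sql("USE default")                           // SetDatabaseCommand -> catalog.setCurrentDatabase
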
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 85f0066f3b..f5aa8fb6fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable}
import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog}
@@ -38,8 +38,8 @@ import org.apache.spark.sql.types._
*/
abstract class NativeDDLCommand(val sql: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.runNativeSql(sql)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.runNativeSql(sql)
}
override val output: Seq[Attribute] = {
@@ -66,8 +66,8 @@ case class CreateDatabase(
props: Map[String, String])
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
catalog.createDatabase(
CatalogDatabase(
databaseName,
@@ -104,8 +104,8 @@ case class DropDatabase(
cascade: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
Seq.empty[Row]
}
@@ -126,8 +126,8 @@ case class AlterDatabaseProperties(
props: Map[String, String])
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
val db: CatalogDatabase = catalog.getDatabaseMetadata(databaseName)
catalog.alterDatabase(db.copy(properties = db.properties ++ props))
@@ -152,9 +152,9 @@ case class DescribeDatabase(
extended: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
val dbMetadata: CatalogDatabase =
- sqlContext.sessionState.catalog.getDatabaseMetadata(databaseName)
+ sparkSession.sessionState.catalog.getDatabaseMetadata(databaseName)
val result =
Row("Database Name", dbMetadata.name) ::
Row("Description", dbMetadata.description) ::
@@ -193,8 +193,8 @@ case class DropTable(
ifExists: Boolean,
isView: Boolean) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
if (!catalog.tableExists(tableName)) {
if (!ifExists) {
val objectName = if (isView) "View" else "Table"
@@ -213,7 +213,7 @@ case class DropTable(
case _ =>
})
try {
- sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString))
+ sparkSession.cacheManager.tryUncacheQuery(sparkSession.table(tableName.quotedString))
} catch {
case NonFatal(e) => log.warn(s"${e.getMessage}", e)
}
@@ -239,8 +239,8 @@ case class AlterTableSetProperties(
isView: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
DDLUtils.verifyAlterTableType(catalog, tableName, isView)
val table = catalog.getTableMetadata(tableName)
val newProperties = table.properties ++ properties
@@ -271,8 +271,8 @@ case class AlterTableUnsetProperties(
isView: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
DDLUtils.verifyAlterTableType(catalog, tableName, isView)
val table = catalog.getTableMetadata(tableName)
if (DDLUtils.isDatasourceTable(table)) {
@@ -315,8 +315,8 @@ case class AlterTableSerDeProperties(
require(serdeClassName.isDefined || serdeProperties.isDefined,
"alter table attempted to set neither serde class name nor serde properties")
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
// Do not support setting serde for datasource tables
if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) {
@@ -350,8 +350,8 @@ case class AlterTableAddPartition(
ifNotExists: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
if (DDLUtils.isDatasourceTable(table)) {
throw new AnalysisException(
@@ -381,8 +381,8 @@ case class AlterTableRenamePartition(
newPartition: TablePartitionSpec)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.catalog.renamePartitions(
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sessionState.catalog.renamePartitions(
tableName, Seq(oldPartition), Seq(newPartition))
Seq.empty[Row]
}
@@ -409,8 +409,8 @@ case class AlterTableDropPartition(
ifExists: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
if (DDLUtils.isDatasourceTable(table)) {
throw new AnalysisException(
@@ -446,8 +446,8 @@ case class AlterTableSetLocation(
location: String)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
partitionSpec match {
case Some(spec) =>
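Every command in the hunks above changes the same way: `run` now takes a `SparkSession` and reaches session-scoped services through `sparkSession.sessionState`. A minimal sketch of that shape, using only calls visible in this patch; the command name and output column are hypothetical, and the class is assumed to sit next to the files above so `RunnableCommand` is in scope.

package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.StringType

// Hypothetical command illustrating the new run(SparkSession) contract; not part of this patch.
case class ExampleShowCurrentDatabase() extends RunnableCommand {

  override val output: Seq[Attribute] =
    Seq(AttributeReference("currentDatabase", StringType, nullable = false)())

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Session-scoped services (catalog, conf, hadoopConf) now hang off sessionState.
    Seq(Row(sparkSession.sessionState.catalog.getCurrentDatabase))
  }
}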
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
index 89ccacdc73..5aa779ddeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogFunction
import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionInfo}
@@ -47,8 +47,8 @@ case class CreateFunction(
isTemp: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
if (isTemp) {
if (databaseName.isDefined) {
throw new AnalysisException(
@@ -99,7 +99,7 @@ case class DescribeFunction(
}
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Hard code "<>", "!=", "between", and "case" for now as there are no corresponding functions.
functionName.toLowerCase match {
case "<>" =>
@@ -116,7 +116,7 @@ case class DescribeFunction(
Row(s"Function: case") ::
Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " +
s"When a = b, returns c; when a = d, return e; else return f") :: Nil
- case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
+ case _ => sparkSession.sessionState.functionRegistry.lookupFunction(functionName) match {
case Some(info) =>
val result =
Row(s"Function: ${info.getName}") ::
@@ -149,8 +149,8 @@ case class DropFunction(
isTemp: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
if (isTemp) {
if (databaseName.isDefined) {
throw new AnalysisException(
@@ -187,12 +187,12 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru
schema.toAttributes
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val dbName = db.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val dbName = db.getOrElse(sparkSession.sessionState.catalog.getCurrentDatabase)
// If pattern is not specified, we use '*', which is used to
// match any sequence of characters (including no characters).
val functionNames =
- sqlContext.sessionState.catalog
+ sparkSession.sessionState.catalog
.listFunctions(dbName, pattern.getOrElse("*"))
.map(_.unquotedString)
// The session catalog caches some persistent functions in the FunctionRegistry
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
index fc7ecb11ec..29bcb30592 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -31,8 +31,8 @@ case class AddJar(path: String) extends RunnableCommand {
schema.toAttributes
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.addJar(path)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sessionState.addJar(path)
Seq(Row(0))
}
}
@@ -41,8 +41,8 @@ case class AddJar(path: String) extends RunnableCommand {
* Adds a file to the current session so it can be used.
*/
case class AddFile(path: String) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sparkContext.addFile(path)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sparkContext.addFile(path)
Seq.empty[Row]
}
}
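`AddJar` and `AddFile` above split along the same line as the rest of the patch: session-scoped work goes through `sparkSession.sessionState`, while cluster-wide work goes through `sparkSession.sparkContext`. A minimal sketch of the latter against the public API; the helper object and path parameter are hypothetical.

import org.apache.spark.sql.SparkSession

object ResourceHelpers {
  // Mirrors AddFile above: files are distributed to executors through the session's SparkContext.
  def distributeLookupFile(spark: SparkSession, path: String): Unit = {
    spark.sparkContext.addFile(path)
  }
}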
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 5cac9d879f..700a704941 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -22,7 +22,7 @@ import java.net.URI
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogTableType, ExternalCatalog}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -60,8 +60,8 @@ case class CreateTableLike(
sourceTable: TableIdentifier,
ifNotExists: Boolean) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
if (!catalog.tableExists(sourceTable)) {
throw new AnalysisException(
s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'")
@@ -109,8 +109,8 @@ case class CreateTableLike(
*/
case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.catalog.createTable(table, ifNotExists)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ sparkSession.sessionState.catalog.createTable(table, ifNotExists)
Seq.empty[Row]
}
@@ -132,8 +132,8 @@ case class AlterTableRename(
isView: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
DDLUtils.verifyAlterTableType(catalog, oldName, isView)
catalog.invalidateTable(oldName)
catalog.renameTable(oldName, newName)
@@ -158,8 +158,8 @@ case class LoadData(
isOverwrite: Boolean,
partition: Option[ExternalCatalog.TablePartitionSpec]) extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
if (!catalog.tableExists(table)) {
throw new AnalysisException(
s"Table in LOAD DATA does not exist: '$table'")
@@ -210,7 +210,7 @@ case class LoadData(
// Follow Hive's behavior:
// If no schema or authority is provided with non-local inpath,
// we will use hadoop configuration "fs.default.name".
- val defaultFSConf = sqlContext.sessionState.hadoopConf.get("fs.default.name")
+ val defaultFSConf = sparkSession.sessionState.hadoopConf.get("fs.default.name")
val defaultFS = if (defaultFSConf == null) {
new URI("")
} else {
@@ -285,9 +285,9 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean)
new MetadataBuilder().putString("comment", "comment of the column").build())()
)
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
val result = new ArrayBuffer[Row]
- sqlContext.sessionState.catalog.lookupRelation(table) match {
+ sparkSession.sessionState.catalog.lookupRelation(table) match {
case catalogRelation: CatalogRelation =>
catalogRelation.catalogTable.schema.foreach { column =>
result += Row(column.name, column.dataType, column.comment.orNull)
@@ -333,10 +333,10 @@ case class ShowTablesCommand(
AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Since we need to return a Seq of rows, we will call getTables directly
- // instead of calling tables in sqlContext.
- val catalog = sqlContext.sessionState.catalog
+ // instead of calling tables in sparkSession.
+ val catalog = sparkSession.sessionState.catalog
val db = databaseName.getOrElse(catalog.getCurrentDatabase)
val tables =
tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db))
@@ -368,13 +368,13 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio
}
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val catalog = sqlContext.sessionState.catalog
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
if (catalog.isTemporaryTable(table)) {
Seq.empty[Row]
} else {
- val catalogTable = sqlContext.sessionState.catalog.getTableMetadata(table)
+ val catalogTable = sparkSession.sessionState.catalog.getTableMetadata(table)
propertyKey match {
case Some(p) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 07cc4a9482..f42b56fdc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command
import scala.util.control.NonFatal
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.SQLBuilder
import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute}
@@ -62,14 +62,14 @@ case class CreateViewCommand(
"It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.")
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// If the plan cannot be analyzed, throw an exception and don't proceed.
- val qe = sqlContext.executePlan(child)
+ val qe = sparkSession.executePlan(child)
qe.assertAnalyzed()
val analyzedPlan = qe.analyzed
require(tableDesc.schema == Nil || tableDesc.schema.length == analyzedPlan.output.length)
- val sessionState = sqlContext.sessionState
+ val sessionState = sparkSession.sessionState
if (sessionState.catalog.tableExists(tableIdentifier)) {
if (allowExisting) {
@@ -77,7 +77,7 @@ case class CreateViewCommand(
// already exists.
} else if (replace) {
// Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...`
- sessionState.catalog.alterTable(prepareTable(sqlContext, analyzedPlan))
+ sessionState.catalog.alterTable(prepareTable(sparkSession, analyzedPlan))
} else {
// Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already
// exists.
@@ -88,7 +88,7 @@ case class CreateViewCommand(
} else {
// Create the view if it doesn't exist.
sessionState.catalog.createTable(
- prepareTable(sqlContext, analyzedPlan), ignoreIfExists = false)
+ prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false)
}
Seq.empty[Row]
@@ -98,9 +98,9 @@ case class CreateViewCommand(
* Returns a [[CatalogTable]] that can be used to save in the catalog. This method canonicalizes
* SQL based on the analyzed plan, and also creates the proper schema for the view.
*/
- private def prepareTable(sqlContext: SQLContext, analyzedPlan: LogicalPlan): CatalogTable = {
+ private def prepareTable(sparkSession: SparkSession, analyzedPlan: LogicalPlan): CatalogTable = {
val viewSQL: String =
- if (sqlContext.conf.canonicalView) {
+ if (sparkSession.sessionState.conf.canonicalView) {
val logicalPlan =
if (tableDesc.schema.isEmpty) {
analyzedPlan
@@ -108,7 +108,7 @@ case class CreateViewCommand(
val projectList = analyzedPlan.output.zip(tableDesc.schema).map {
case (attr, col) => Alias(attr, col.name)()
}
- sqlContext.executePlan(Project(projectList, analyzedPlan)).analyzed
+ sparkSession.executePlan(Project(projectList, analyzedPlan)).analyzed
}
new SQLBuilder(logicalPlan).toSQL
} else {
@@ -134,7 +134,7 @@ case class CreateViewCommand(
// Validate the view SQL - make sure we can parse it and analyze it.
// If we cannot analyze the generated query, there is probably a bug in SQL generation.
try {
- sqlContext.sql(viewSQL).queryExecution.assertAnalyzed()
+ sparkSession.sql(viewSQL).queryExecution.assertAnalyzed()
} catch {
case NonFatal(e) =>
throw new RuntimeException(
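`CreateViewCommand` now analyzes and validates the generated view SQL through the session itself (`sparkSession.executePlan`, `sparkSession.sql`) rather than a `SQLContext`. A small sketch of the validation step, assuming the public `sql` entry point and the `queryExecution.assertAnalyzed()` call used above are reachable; the object name is hypothetical.

import org.apache.spark.sql.SparkSession

object ViewSqlCheck {
  // Parse and analyze generated SQL; an unanalyzable statement throws here,
  // which is how CreateViewCommand above detects SQL-generation bugs.
  def assertAnalyzable(spark: SparkSession, viewSQL: String): Unit = {
    spark.sql(viewSQL).queryExecution.assertAnalyzed()
  }
}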
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 4e7214ce83..ef626ef5fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -59,7 +59,7 @@ import org.apache.spark.util.Utils
* @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data.
*/
case class DataSource(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
className: String,
paths: Seq[String] = Nil,
userSpecifiedSchema: Option[StructType] = None,
@@ -131,15 +131,15 @@ case class DataSource(
val allPaths = caseInsensitiveOptions.get("path")
val globbedPaths = allPaths.toSeq.flatMap { path =>
val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(sqlContext.sessionState.hadoopConf)
+ val fs = hdfsPath.getFileSystem(sparkSession.sessionState.hadoopConf)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
SparkHadoopUtil.get.globPathIfNecessary(qualified)
}.toArray
- val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None)
+ val fileCatalog: FileCatalog = new HDFSFileCatalog(sparkSession, options, globbedPaths, None)
userSpecifiedSchema.orElse {
format.inferSchema(
- sqlContext,
+ sparkSession,
caseInsensitiveOptions,
fileCatalog.allFiles())
}.getOrElse {
@@ -151,7 +151,8 @@ case class DataSource(
private def sourceSchema(): SourceInfo = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
- val (name, schema) = s.sourceSchema(sqlContext, userSpecifiedSchema, className, options)
+ val (name, schema) = s.sourceSchema(
+ sparkSession.wrapped, userSpecifiedSchema, className, options)
SourceInfo(name, schema)
case format: FileFormat =>
@@ -171,7 +172,7 @@ case class DataSource(
def createSource(metadataPath: String): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
- s.createSource(sqlContext, metadataPath, userSpecifiedSchema, className, options)
+ s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options)
case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -183,16 +184,16 @@ case class DataSource(
val newOptions = options.filterKeys(_ != "path") + ("basePath" -> path)
val newDataSource =
DataSource(
- sqlContext,
+ sparkSession,
paths = files,
userSpecifiedSchema = Some(sourceInfo.schema),
className = className,
options = new CaseInsensitiveMap(newOptions))
- Dataset.ofRows(sqlContext, LogicalRelation(newDataSource.resolveRelation()))
+ Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation()))
}
new FileStreamSource(
- sqlContext, metadataPath, path, sourceInfo.schema, dataFrameBuilder)
+ sparkSession, metadataPath, path, sourceInfo.schema, dataFrameBuilder)
case _ =>
throw new UnsupportedOperationException(
s"Data source $className does not support streamed reading")
@@ -202,14 +203,14 @@ case class DataSource(
/** Returns a sink that can be used to continually write data. */
def createSink(): Sink = {
providingClass.newInstance() match {
- case s: StreamSinkProvider => s.createSink(sqlContext, options, partitionColumns)
+ case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns)
case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})
- new FileStreamSink(sqlContext, path, format)
+ new FileStreamSink(sparkSession, path, format)
case _ =>
throw new UnsupportedOperationException(
s"Data source $className does not support streamed writing")
@@ -225,7 +226,7 @@ case class DataSource(
case Seq(singlePath) =>
try {
val hdfsPath = new Path(singlePath)
- val fs = hdfsPath.getFileSystem(sqlContext.sessionState.hadoopConf)
+ val fs = hdfsPath.getFileSystem(sparkSession.sessionState.hadoopConf)
val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir)
val res = fs.exists(metadataPath)
res
@@ -244,9 +245,9 @@ case class DataSource(
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
- dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema)
+ dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema)
case (dataSource: RelationProvider, None) =>
- dataSource.createRelation(sqlContext, caseInsensitiveOptions)
+ dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions)
case (_: SchemaRelationProvider, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $className.")
case (_: RelationProvider, Some(_)) =>
@@ -257,11 +258,10 @@ case class DataSource(
case (format: FileFormat, _)
if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) =>
val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
- val fileCatalog =
- new StreamFileCatalog(sqlContext, basePath)
+ val fileCatalog = new StreamFileCatalog(sparkSession, basePath)
val dataSchema = userSpecifiedSchema.orElse {
format.inferSchema(
- sqlContext,
+ sparkSession,
caseInsensitiveOptions,
fileCatalog.allFiles())
}.getOrElse {
@@ -271,7 +271,7 @@ case class DataSource(
}
HadoopFsRelation(
- sqlContext,
+ sparkSession,
fileCatalog,
partitionSchema = fileCatalog.partitionSpec().partitionColumns,
dataSchema = dataSchema,
@@ -284,7 +284,7 @@ case class DataSource(
val allPaths = caseInsensitiveOptions.get("path") ++ paths
val globbedPaths = allPaths.flatMap { path =>
val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(sqlContext.sessionState.hadoopConf)
+ val fs = hdfsPath.getFileSystem(sparkSession.sessionState.hadoopConf)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified)
@@ -311,11 +311,11 @@ case class DataSource(
}
val fileCatalog: FileCatalog =
- new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema)
+ new HDFSFileCatalog(sparkSession, options, globbedPaths, partitionSchema)
val dataSchema = userSpecifiedSchema.map { schema =>
val equality =
- if (sqlContext.conf.caseSensitiveAnalysis) {
+ if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
@@ -324,7 +324,7 @@ case class DataSource(
StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
}.orElse {
format.inferSchema(
- sqlContext,
+ sparkSession,
caseInsensitiveOptions,
fileCatalog.allFiles())
}.getOrElse {
@@ -334,10 +334,10 @@ case class DataSource(
}
val enrichedOptions =
- format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles())
+ format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles())
HadoopFsRelation(
- sqlContext,
+ sparkSession,
fileCatalog,
partitionSchema = fileCatalog.partitionSpec().partitionColumns,
dataSchema = dataSchema.asNullable,
@@ -363,7 +363,7 @@ case class DataSource(
providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sqlContext, mode, options, data)
+ dataSource.createRelation(sparkSession.wrapped, mode, options, data)
case format: FileFormat =>
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;
@@ -374,11 +374,11 @@ case class DataSource(
val path = new Path(caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
}))
- val fs = path.getFileSystem(sqlContext.sessionState.hadoopConf)
+ val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf)
path.makeQualified(fs.getUri, fs.getWorkingDirectory)
}
- val caseSensitive = sqlContext.conf.caseSensitiveAnalysis
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumnDataTypes(
data.schema, partitionColumns, caseSensitive)
@@ -421,7 +421,7 @@ case class DataSource(
options,
data.logicalPlan,
mode)
- sqlContext.executePlan(plan).toRdd
+ sparkSession.executePlan(plan).toRdd
case _ =>
sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
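`DataSource` now carries a `SparkSession` but keeps handing `sparkSession.wrapped` (a `SQLContext`) to the public provider interfaces, so external sources are untouched. A minimal sketch of such a source against the stable `sources` API; the provider name and schema are hypothetical.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class ExampleProvider extends RelationProvider {
  // Still receives a SQLContext; DataSource now builds it from sparkSession.wrapped.
  override def createRelation(
      ctx: SQLContext,
      parameters: Map[String, String]): BaseRelation = new BaseRelation {
    override val sqlContext: SQLContext = ctx
    override val schema: StructType = StructType(StructField("value", StringType) :: Nil)
  }
}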
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 60238bd515..f7f68b1eb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.{Partition => RDDPartition, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{InputFileNameHolder, RDD}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
@@ -51,10 +51,10 @@ case class PartitionedFile(
case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends RDDPartition
class FileScanRDD(
- @transient val sqlContext: SQLContext,
+ @transient private val sparkSession: SparkSession,
readFunction: (PartitionedFile) => Iterator[InternalRow],
@transient val filePartitions: Seq[FilePartition])
- extends RDD[InternalRow](sqlContext.sparkContext, Nil) {
+ extends RDD[InternalRow](sparkSession.sparkContext, Nil) {
override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = {
val iterator = new Iterator[Object] with AutoCloseable {
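`FileScanRDD` now holds the session as a `@transient private val`, so only the `SparkContext`-derived pieces are shipped to executors. A stripped-down sketch of that serialization pattern using public RDD APIs; the class name and element type are hypothetical.

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

class SessionBackedRDD(@transient private val sparkSession: SparkSession)
  extends RDD[Int](sparkSession.sparkContext, Nil) {

  // A single dummy partition; FileScanRDD plans these from FilePartitions instead.
  override protected def getPartitions: Array[Partition] =
    Array(new Partition { override val index: Int = 0 })

  // Runs on executors; note it never touches the @transient session.
  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator(split.index)
}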
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 751daa0fe2..9e1308bed5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -74,14 +74,14 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
}
val partitionColumns =
- l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver)
+ l.resolve(files.partitionSchema, files.sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters =
ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
val dataColumns =
- l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver)
+ l.resolve(files.dataSchema, files.sparkSession.sessionState.analyzer.resolver)
// Partition keys are not available in the statistics of the files.
val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty)
@@ -107,7 +107,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")
val readFile = files.fileFormat.buildReader(
- sqlContext = files.sqlContext,
+ sparkSession = files.sparkSession,
dataSchema = files.dataSchema,
partitionSchema = files.partitionSchema,
requiredSchema = prunedDataSchema,
@@ -115,7 +115,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
options = files.options)
val plannedPartitions = files.bucketSpec match {
- case Some(bucketing) if files.sqlContext.conf.bucketingEnabled =>
+ case Some(bucketing) if files.sparkSession.sessionState.conf.bucketingEnabled =>
logInfo(s"Planning with ${bucketing.numBuckets} buckets")
val bucketed =
selectedPartitions.flatMap { p =>
@@ -134,9 +134,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
}
case _ =>
- val defaultMaxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes
- val openCostInBytes = files.sqlContext.conf.filesOpenCostInBytes
- val defaultParallelism = files.sqlContext.sparkContext.defaultParallelism
+ val defaultMaxSplitBytes = files.sparkSession.sessionState.conf.filesMaxPartitionBytes
+ val openCostInBytes = files.sparkSession.sessionState.conf.filesOpenCostInBytes
+ val defaultParallelism = files.sparkSession.sparkContext.defaultParallelism
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
val bytesPerCore = totalBytes / defaultParallelism
val maxSplitBytes = Math.min(defaultMaxSplitBytes,
@@ -195,7 +195,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
DataSourceScanExec.create(
readDataColumns ++ partitionColumns,
new FileScanRDD(
- files.sqlContext,
+ files.sparkSession,
readFile,
plannedPartitions),
files,
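The non-bucketed branch above sizes file splits from three per-session settings. A tiny standalone sketch of that arithmetic with illustrative numbers; the final clamp to `math.max(openCostInBytes, bytesPerCore)` is assumed from the surrounding logic, since the hunk is cut off at `Math.min(defaultMaxSplitBytes,`.

object SplitSizeSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative values: 128 MB max split, 4 MB open cost, 8 cores, 10 files of 64 MB.
    val defaultMaxSplitBytes = 128L * 1024 * 1024
    val openCostInBytes = 4L * 1024 * 1024
    val defaultParallelism = 8
    val fileSizes = Seq.fill(10)(64L * 1024 * 1024)

    val totalBytes = fileSizes.map(_ + openCostInBytes).sum
    val bytesPerCore = totalBytes / defaultParallelism
    val maxSplitBytes = math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
    println(s"maxSplitBytes = $maxSplitBytes")  // 89128960 bytes, i.e. 85 MB per split
  }
}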
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala
index 37c2c4517c..7b15e49641 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala
@@ -32,15 +32,15 @@ private[sql] case class InsertIntoDataSource(
overwrite: Boolean)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
- val data = Dataset.ofRows(sqlContext, query)
+ val data = Dataset.ofRows(sparkSession, query)
// Apply the schema of the existing table to the new data.
- val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+ val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)
// Invalidate the cache.
- sqlContext.cacheManager.invalidateCache(logicalRelation)
+ sparkSession.cacheManager.invalidateCache(logicalRelation)
Seq.empty[Row]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index a636ca2f29..b2483e69a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -68,7 +68,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
override def children: Seq[LogicalPlan] = query :: Nil
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Most formats don't do well with duplicate columns, so let's not allow that
if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) {
val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
@@ -78,7 +78,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
s"cannot save to file.")
}
- val hadoopConf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val hadoopConf = new Configuration(sparkSession.sessionState.hadoopConf)
val fs = outputPath.getFileSystem(hadoopConf)
val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
@@ -111,14 +111,14 @@ private[sql] case class InsertIntoHadoopFsRelation(
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = query.output.filterNot(partitionSet.contains)
- val queryExecution = Dataset.ofRows(sqlContext, query).queryExecution
- SQLExecution.withNewExecutionId(sqlContext, queryExecution) {
+ val queryExecution = Dataset.ofRows(sparkSession, query).queryExecution
+ SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
val relation =
WriteRelation(
- sqlContext,
+ sparkSession,
dataColumns.toStructType,
qualifiedOutputPath.toString,
- fileFormat.prepareWrite(sqlContext, _, options, dataColumns.toStructType),
+ fileFormat.prepareWrite(sparkSession, _, options, dataColumns.toStructType),
bucketSpec)
val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) {
@@ -131,7 +131,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
dataColumns = dataColumns,
inputSchema = query.output,
PartitioningUtils.DEFAULT_PARTITION_NAME,
- sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES),
+ sparkSession.getConf(SQLConf.PARTITION_MAX_FILES),
isAppend)
}
@@ -140,7 +140,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
writerContainer.driverSideSetup()
try {
- sqlContext.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _)
+ sparkSession.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _)
writerContainer.commitJob()
refreshFunction()
} catch { case cause: Throwable =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index b9527db6d0..3b064a5bc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.mapred.SparkHadoopMapRedUtil
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
@@ -36,9 +36,10 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
+
/** A container for all the details required when writing to a table. */
case class WriteRelation(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
path: String,
prepareJobForWrite: Job => OutputWriterFactory,
@@ -66,7 +67,7 @@ private[sql] abstract class BaseWriterContainer(
@transient private val jobContext: JobContext = job
private val speculationEnabled: Boolean =
- relation.sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false)
+ relation.sparkSession.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false)
// The following fields are initialized and used on both driver and executor side.
@transient protected var outputCommitter: OutputCommitter = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 7d407a7747..fb047ff867 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -49,14 +49,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
val csvOptions = new CSVOptions(options)
// TODO: Move filtering.
val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
- val rdd = baseRdd(sqlContext, csvOptions, paths)
+ val rdd = baseRdd(sparkSession, csvOptions, paths)
val firstLine = findFirstLine(csvOptions, rdd)
val firstRow = new LineCsvReader(csvOptions).parseLine(firstLine)
@@ -66,7 +66,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
}
- val parsedRdd = tokenRdd(sqlContext, csvOptions, header, paths)
+ val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
val schema = if (csvOptions.inferSchemaFlag) {
CSVInferSchema.infer(parsedRdd, header, csvOptions.nullValue)
} else {
@@ -80,7 +80,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -94,7 +94,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
@@ -103,8 +103,8 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val csvOptions = new CSVOptions(options)
val headers = requiredSchema.fields.map(_.name)
- val conf = new Configuration(sqlContext.sessionState.hadoopConf)
- val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+ val conf = new Configuration(sparkSession.sessionState.hadoopConf)
+ val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf))
(file: PartitionedFile) => {
val lineIterator = {
@@ -134,18 +134,18 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
private def baseRdd(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): RDD[String] = {
- readText(sqlContext, options, inputPaths.mkString(","))
+ readText(sparkSession, options, inputPaths.mkString(","))
}
private def tokenRdd(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: CSVOptions,
header: Array[String],
inputPaths: Seq[String]): RDD[Array[String]] = {
- val rdd = baseRdd(sqlContext, options, inputPaths)
+ val rdd = baseRdd(sparkSession, options, inputPaths)
// Make sure firstLine is materialized before sending to executors
val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
CSVRelation.univocityTokenizer(rdd, header, firstLine, options)
@@ -168,14 +168,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
private def readText(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: CSVOptions,
location: String): RDD[String] = {
if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
- sqlContext.sparkContext.textFile(location)
+ sparkSession.sparkContext.textFile(location)
} else {
val charset = options.charset
- sqlContext.sparkContext
+ sparkSession.sparkContext
.hadoopFile[LongWritable, Text, TextInputFormat](location)
.mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index e7e94bbef8..7d0a3d9756 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -74,15 +74,15 @@ case class CreateTempTableUsing(
s"Temporary table '$tableIdent' should not have specified a database")
}
- def run(sqlContext: SQLContext): Seq[Row] = {
+ def run(sparkSession: SparkSession): Seq[Row] = {
val dataSource = DataSource(
- sqlContext,
+ sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
className = provider,
options = options)
- sqlContext.sessionState.catalog.createTempTable(
+ sparkSession.sessionState.catalog.createTempTable(
tableIdent.table,
- Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
+ Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
overrideIfExists = true)
Seq.empty[Row]
@@ -102,18 +102,18 @@ case class CreateTempTableUsingAsSelect(
s"Temporary table '$tableIdent' should not have specified a database")
}
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val df = Dataset.ofRows(sqlContext, query)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val df = Dataset.ofRows(sparkSession, query)
val dataSource = DataSource(
- sqlContext,
+ sparkSession,
className = provider,
partitionColumns = partitionColumns,
bucketSpec = None,
options = options)
val result = dataSource.write(mode, df)
- sqlContext.sessionState.catalog.createTempTable(
+ sparkSession.sessionState.catalog.createTempTable(
tableIdent.table,
- Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan,
+ Dataset.ofRows(sparkSession, LogicalRelation(result)).logicalPlan,
overrideIfExists = true)
Seq.empty[Row]
@@ -123,23 +123,23 @@ case class CreateTempTableUsingAsSelect(
case class RefreshTable(tableIdent: TableIdentifier)
extends RunnableCommand {
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Refresh the given table's metadata first.
- sqlContext.sessionState.catalog.refreshTable(tableIdent)
+ sparkSession.sessionState.catalog.refreshTable(tableIdent)
// If this table is cached as a InMemoryColumnarRelation, drop the original
// cached version and make the new version cached lazily.
- val logicalPlan = sqlContext.sessionState.catalog.lookupRelation(tableIdent)
+ val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent)
// Use lookupCachedData directly since RefreshTable also takes databaseName.
- val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty
+ val isCached = sparkSession.cacheManager.lookupCachedData(logicalPlan).nonEmpty
if (isCached) {
// Create a data frame to represent the table.
// TODO: Use uncacheTable once it supports database name.
- val df = Dataset.ofRows(sqlContext, logicalPlan)
+ val df = Dataset.ofRows(sparkSession, logicalPlan)
// Uncache the logicalPlan.
- sqlContext.cacheManager.tryUncacheQuery(df, blocking = true)
+ sparkSession.cacheManager.tryUncacheQuery(df, blocking = true)
// Cache it again.
- sqlContext.cacheManager.cacheQuery(df, Some(tableIdent.table))
+ sparkSession.cacheManager.cacheQuery(df, Some(tableIdent.table))
}
Seq.empty[Row]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 731b0047e5..2628788ad3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -119,7 +119,7 @@ abstract class OutputWriter {
* @param options Configuration used when reading / writing data.
*/
case class HadoopFsRelation(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
location: FileCatalog,
partitionSchema: StructType,
dataSchema: StructType,
@@ -127,6 +127,8 @@ case class HadoopFsRelation(
fileFormat: FileFormat,
options: Map[String, String]) extends BaseRelation with FileRelation {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
StructType(dataSchema ++ partitionSchema.filterNot { column =>
@@ -160,7 +162,7 @@ trait FileFormat {
* Spark will require that user specify the schema manually.
*/
def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType]
@@ -169,7 +171,7 @@ trait FileFormat {
* can be useful for collecting necessary global information for scanning input data.
*/
def prepareRead(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Map[String, String] = options
@@ -179,7 +181,7 @@ trait FileFormat {
* by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
*/
def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory
@@ -189,7 +191,7 @@ trait FileFormat {
*
* TODO: we should just have different traits for the different formats.
*/
- def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = {
+ def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = {
false
}
@@ -210,7 +212,7 @@ trait FileFormat {
* @return
*/
def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
@@ -265,13 +267,13 @@ trait FileCatalog {
* discovered partitions
*/
class HDFSFileCatalog(
- val sqlContext: SQLContext,
- val parameters: Map[String, String],
- val paths: Seq[Path],
- val partitionSchema: Option[StructType])
+ sparkSession: SparkSession,
+ parameters: Map[String, String],
+ override val paths: Seq[Path],
+ partitionSchema: Option[StructType])
extends FileCatalog with Logging {
- private val hadoopConf = new Configuration(sqlContext.sessionState.hadoopConf)
+ private val hadoopConf = new Configuration(sparkSession.sessionState.hadoopConf)
var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus]
var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]]
@@ -339,8 +341,8 @@ class HDFSFileCatalog(
def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path)
private def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {
- if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) {
- HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext)
+ if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) {
+ HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession.sparkContext)
} else {
val statuses: Seq[FileStatus] = paths.flatMap { path =>
val fs = path.getFileSystem(hadoopConf)
@@ -412,7 +414,7 @@ class HDFSFileCatalog(
PartitioningUtils.parsePartitions(
leafDirs,
PartitioningUtils.DEFAULT_PARTITION_NAME,
- typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(),
+ typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled(),
basePaths = basePaths)
}
}
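`HadoopFsRelation` here, and `JDBCRelation` below, keep a `SparkSession` internally and satisfy `BaseRelation`'s public `sqlContext` through `sparkSession.wrapped`. A minimal sketch of that bridge; the relation name and schema are hypothetical, and the class is assumed to live inside the `org.apache.spark.sql` tree so the internal `wrapped` accessor is visible.

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Hypothetical relation; not part of this patch.
case class ExampleSessionRelation(sparkSession: SparkSession) extends BaseRelation {
  // The public contract still speaks SQLContext; the session's wrapper satisfies it.
  override def sqlContext: SQLContext = sparkSession.wrapped
  override def schema: StructType = StructType(StructField("id", LongType) :: Nil)
}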
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
index 4dcd261f5c..6ff50a3c61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
@@ -54,6 +54,6 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
val parts = JDBCRelation.columnPartition(partitionInfo)
val properties = new Properties() // Additional properties that we will pass to getConnection
parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
- JDBCRelation(url, table, parts, properties)(sqlContext)
+ JDBCRelation(url, table, parts, properties)(sqlContext.sparkSession)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 9e336422d1..bcf70fdc4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -87,11 +87,13 @@ private[sql] case class JDBCRelation(
url: String,
table: String,
parts: Array[Partition],
- properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
+ properties: Properties = new Properties())(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan
with InsertableRelation {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
override val needConversion: Boolean = false
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
@@ -104,7 +106,7 @@ private[sql] case class JDBCRelation(
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
- sqlContext.sparkContext,
+ sparkSession.sparkContext,
schema,
url,
properties,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 580a0e1de6..f9c34c6bb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -44,7 +44,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
override def shortName(): String = "json"
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
if (files.isEmpty) {
@@ -53,14 +53,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val parsedOptions: JSONOptions = new JSONOptions(options)
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
+ .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val jsonFiles = files.filterNot { status =>
val name = status.getPath.getName
name.startsWith("_") || name.startsWith(".")
}.toArray
val jsonSchema = InferSchema.infer(
- createBaseRdd(sqlContext, jsonFiles),
+ createBaseRdd(sparkSession, jsonFiles),
columnNameOfCorruptRecord,
parsedOptions)
checkConstraints(jsonSchema)
@@ -70,7 +70,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -92,19 +92,19 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
- val conf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val conf = new Configuration(sparkSession.sessionState.hadoopConf)
val broadcastedConf =
- sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf))
val parsedOptions: JSONOptions = new JSONOptions(options)
val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
+ .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
val joinedRow = new JoinedRow()
@@ -125,8 +125,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
- private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = {
- val job = Job.getInstance(sqlContext.sessionState.hadoopConf)
+ private def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus]): RDD[String] = {
+ val job = Job.getInstance(sparkSession.sessionState.hadoopConf)
val conf = job.getConfiguration
val paths = inputPaths.map(_.getPath)
@@ -135,7 +137,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
FileInputFormat.setInputPaths(job, paths: _*)
}
- sqlContext.sparkContext.hadoopRDD(
+ sparkSession.sparkContext.hadoopRDD(
conf.asInstanceOf[JobConf],
classOf[TextInputFormat],
classOf[LongWritable],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 28c6664085..b156581564 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -65,12 +65,12 @@ private[sql] class DefaultSource
override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- val parquetOptions = new ParquetOptions(options, sqlContext.sessionState.conf)
+ val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
val conf = ContextUtil.getConfiguration(job)
@@ -110,15 +110,15 @@ private[sql] class DefaultSource
// and `CatalystWriteSupport` (writing actual rows to Parquet files).
conf.set(
SQLConf.PARQUET_BINARY_AS_STRING.key,
- sqlContext.conf.isParquetBinaryAsString.toString)
+ sparkSession.sessionState.conf.isParquetBinaryAsString.toString)
conf.set(
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
- sqlContext.conf.isParquetINT96AsTimestamp.toString)
+ sparkSession.sessionState.conf.isParquetINT96AsTimestamp.toString)
conf.set(
SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key,
- sqlContext.conf.writeLegacyParquetFormat.toString)
+ sparkSession.sessionState.conf.writeLegacyParquetFormat.toString)
// Sets compression scheme
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec)
@@ -135,7 +135,7 @@ private[sql] class DefaultSource
}
def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
parameters: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
// Should we merge schemas from all Parquet part-files?
@@ -143,10 +143,10 @@ private[sql] class DefaultSource
parameters
.get(ParquetRelation.MERGE_SCHEMA)
.map(_.toBoolean)
- .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED))
+ .getOrElse(sparkSession.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED))
val mergeRespectSummaries =
- sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES)
+ sparkSession.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES)
val filesByType = splitFiles(files)
@@ -218,7 +218,7 @@ private[sql] class DefaultSource
.orElse(filesByType.data.headOption)
.toSeq
}
- ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext)
+ ParquetRelation.mergeSchemasInParallel(filesToTouch, sparkSession)
}
case class FileTypes(
@@ -249,21 +249,21 @@ private[sql] class DefaultSource
/**
* Returns whether the reader will return the rows as batch or not.
*/
- override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = {
- val conf = SQLContext.getActive().get.conf
+ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
+ val conf = sparkSession.sessionState.conf
conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled &&
schema.length <= conf.wholeStageMaxNumFields &&
schema.forall(_.dataType.isInstanceOf[AtomicType])
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
- val parquetConf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val parquetConf = new Configuration(sparkSession.sessionState.hadoopConf)
parquetConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName)
parquetConf.set(
CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
@@ -281,13 +281,13 @@ private[sql] class DefaultSource
// Sets flags for `CatalystSchemaConverter`
parquetConf.setBoolean(
SQLConf.PARQUET_BINARY_AS_STRING.key,
- sqlContext.conf.getConf(SQLConf.PARQUET_BINARY_AS_STRING))
+ sparkSession.getConf(SQLConf.PARQUET_BINARY_AS_STRING))
parquetConf.setBoolean(
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
- sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP))
+ sparkSession.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP))
// Try to push down filters when filter push-down is enabled.
- val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
+ val pushed = if (sparkSession.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
@@ -299,16 +299,17 @@ private[sql] class DefaultSource
}
val broadcastedConf =
- sqlContext.sparkContext.broadcast(new SerializableConfiguration(parquetConf))
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(parquetConf))
// TODO: if you move this into the closure it reverts to the default values.
// If true, enable using the custom RecordReader for parquet. This only works for
// a subset of the types (no complex types).
val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
- val enableVectorizedReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
+ val enableVectorizedReader: Boolean =
+ sparkSession.sessionState.conf.parquetVectorizedReaderEnabled &&
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
// Whole stage codegen (PhysicalRDD) is able to deal with batches directly
- val returningBatch = supportBatch(sqlContext, resultSchema)
+ val returningBatch = supportBatch(sparkSession, resultSchema)
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
@@ -507,13 +508,13 @@ private[sql] object ParquetRelation extends Logging {
}
private[parquet] def readSchema(
- footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = {
+ footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = {
def parseParquetSchema(schema: MessageType): StructType = {
val converter = new CatalystSchemaConverter(
- sqlContext.conf.isParquetBinaryAsString,
- sqlContext.conf.isParquetBinaryAsString,
- sqlContext.conf.writeLegacyParquetFormat)
+ sparkSession.sessionState.conf.isParquetBinaryAsString,
+ sparkSession.sessionState.conf.isParquetBinaryAsString,
+ sparkSession.sessionState.conf.writeLegacyParquetFormat)
converter.convert(schema)
}
@@ -644,11 +645,13 @@ private[sql] object ParquetRelation extends Logging {
* S3 nodes).
*/
def mergeSchemasInParallel(
- filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = {
- val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
- val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
- val writeLegacyParquetFormat = sqlContext.conf.writeLegacyParquetFormat
- val serializedConf = new SerializableConfiguration(sqlContext.sessionState.hadoopConf)
+ filesToTouch: Seq[FileStatus],
+ sparkSession: SparkSession): Option[StructType] = {
+ val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString
+ val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp
+ val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat
+ val serializedConf =
+ new SerializableConfiguration(sparkSession.sessionState.hadoopConf)
// !! HACK ALERT !!
//
@@ -665,7 +668,7 @@ private[sql] object ParquetRelation extends Logging {
// Issues a Spark job to read Parquet schema in parallel.
val partiallyMergedSchemas =
- sqlContext
+ sparkSession
.sparkContext
.parallelize(partialFileStatusInfo)
.mapPartitions { iterator =>
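
After this change, every Parquet-related flag is looked up through the session rather than through SQLContext.getActive(). A minimal sketch of the new lookup path, assuming a SparkSession named `spark` (a placeholder, not part of the patch):

    val conf = spark.sessionState.conf
    val binaryAsString    = conf.isParquetBinaryAsString       // SQLConf.PARQUET_BINARY_AS_STRING
    val int96AsTimestamp  = conf.isParquetINT96AsTimestamp     // SQLConf.PARQUET_INT96_AS_TIMESTAMP
    val writeLegacyFormat = conf.writeLegacyParquetFormat      // legacy Parquet write format flag
    val vectorizedReader  = conf.parquetVectorizedReaderEnabled
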
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 5b8dc4a3ee..b622f85941 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
+import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
@@ -30,12 +30,12 @@ import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
/**
* Try to replace [[UnresolvedRelation]]s with [[ResolvedDataSource]].
*/
-private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[LogicalPlan] {
+private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedRelation if u.tableIdentifier.database.isDefined =>
try {
val dataSource = DataSource(
- sqlContext,
+ sparkSession,
paths = u.tableIdentifier.table :: Nil,
className = u.tableIdentifier.database.get)
val plan = LogicalRelation(dataSource.resolveRelation())
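
As the hunk shows, `DataSource` now takes a `SparkSession` as its first argument. A hedged sketch of resolving a relation under the new signature (the path and format name below are placeholders):

    val dataSource = DataSource(
      sparkSession,
      paths = Seq("/tmp/some/table"),   // placeholder path
      className = "parquet")            // placeholder format name
    val plan = LogicalRelation(dataSource.resolveRelation())
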
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index f7ac1ac8e4..a0d680c708 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
@@ -52,12 +52,12 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = Some(new StructType().add("value", StringType))
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -84,15 +84,15 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
- val conf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val conf = new Configuration(sparkSession.sessionState.hadoopConf)
val broadcastedConf =
- sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf))
file => {
val unsafeRow = new UnsafeRow(1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 8c2231335c..34bd243d58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging {
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
- Dataset.ofRows(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
+ Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index d603f63a08..9c0406168e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging {
}
val schema = StructType(StructField(tableName, StringType) +: headerNames)
- Dataset.ofRows(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+ Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
}
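
Both stat helpers now construct their result frames with `Dataset.ofRows` on the owning SparkSession instead of the SQLContext. A condensed sketch of the pattern, with `schema` and `rows` standing in for the locally built values:

    val result: DataFrame =
      Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, rows))
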
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index a86108862f..61c9c88cb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -22,7 +22,7 @@ import java.util.UUID
import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.datasources.FileFormat
object FileStreamSink {
@@ -38,14 +38,14 @@ object FileStreamSink {
* in the log.
*/
class FileStreamSink(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
path: String,
fileFormat: FileFormat) extends Sink with Logging {
private val basePath = new Path(path)
private val logPath = new Path(basePath, FileStreamSink.metadataDir)
- private val fileLog = new FileStreamSinkLog(sqlContext, logPath.toUri.toString)
- private val fs = basePath.getFileSystem(sqlContext.sessionState.hadoopConf)
+ private val fileLog = new FileStreamSinkLog(sparkSession, logPath.toUri.toString)
+ private val fs = basePath.getFileSystem(sparkSession.sessionState.hadoopConf)
override def addBatch(batchId: Long, data: DataFrame): Unit = {
if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
@@ -73,7 +73,7 @@ class FileStreamSink(
private def writeFiles(data: DataFrame): Array[Path] = {
val file = new Path(basePath, UUID.randomUUID().toString).toUri.toString
data.write.parquet(file)
- sqlContext.read
+ sparkSession.read
.schema(data.schema)
.parquet(file)
.inputFiles
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
index 6c5449a928..c548fbd369 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
@@ -25,7 +25,7 @@ import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
import org.json4s.jackson.Serialization.{read, write}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
/**
@@ -66,8 +66,8 @@ case class SinkFileStatus(
* When the reader uses `allFiles` to list all files, this method only returns the visible files
* (drops the deleted files).
*/
-class FileStreamSinkLog(sqlContext: SQLContext, path: String)
- extends HDFSMetadataLog[Seq[SinkFileStatus]](sqlContext, path) {
+class FileStreamSinkLog(sparkSession: SparkSession, path: String)
+ extends HDFSMetadataLog[Seq[SinkFileStatus]](sparkSession, path) {
import FileStreamSinkLog._
@@ -80,11 +80,11 @@ class FileStreamSinkLog(sqlContext: SQLContext, path: String)
* a live lock may happen if the compaction happens too frequently: one process keeps deleting
* old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it.
*/
- private val fileCleanupDelayMs = sqlContext.getConf(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY)
+ private val fileCleanupDelayMs = sparkSession.getConf(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY)
- private val isDeletingExpiredLog = sqlContext.getConf(SQLConf.FILE_SINK_LOG_DELETION)
+ private val isDeletingExpiredLog = sparkSession.getConf(SQLConf.FILE_SINK_LOG_DELETION)
- private val compactInterval = sqlContext.getConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL)
+ private val compactInterval = sparkSession.getConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL)
require(compactInterval > 0,
s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " +
"to a positive value.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index aeb64c929c..e22a05bd3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.streaming
import scala.collection.mutable.ArrayBuffer
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.collection.OpenHashSet
@@ -32,14 +32,14 @@ import org.apache.spark.util.collection.OpenHashSet
* TODO Clean up the metadata files periodically
*/
class FileStreamSource(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
metadataPath: String,
path: String,
override val schema: StructType,
dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging {
- private val fs = new Path(path).getFileSystem(sqlContext.sessionState.hadoopConf)
- private val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath)
+ private val fs = new Path(path).getFileSystem(sparkSession.sessionState.hadoopConf)
+ private val metadataLog = new HDFSMetadataLog[Seq[String]](sparkSession, metadataPath)
private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)
private val seenFiles = new OpenHashSet[String]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index dd6760d341..ddba3ccb1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.permission.FsPermission
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
/**
@@ -45,7 +45,7 @@ import org.apache.spark.sql.SQLContext
* Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing
* files in a directory always shows the latest files.
*/
-class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String)
+class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
extends MetadataLog[T]
with Logging {
@@ -65,7 +65,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String)
override def accept(path: Path): Boolean = isBatchFile(path)
}
- private val serializer = new JavaSerializer(sqlContext.sparkContext.conf).newInstance()
+ private val serializer = new JavaSerializer(sparkSession.sparkContext.conf).newInstance()
protected def batchIdToPath(batchId: Long): Path = {
new Path(metadataPath, batchId.toString)
@@ -212,7 +212,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String)
}
private def createFileManager(): FileManager = {
- val hadoopConf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val hadoopConf = new Configuration(sparkSession.sessionState.hadoopConf)
try {
new FileContextManager(metadataPath, hadoopConf)
} catch {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index a1a1108447..b89144d727 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.streaming
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.OutputMode
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -28,20 +28,21 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner,
* plan incrementally. Possibly preserving state in between each execution.
*/
class IncrementalExecution(
- ctx: SQLContext,
+ sparkSession: SparkSession,
logicalPlan: LogicalPlan,
outputMode: OutputMode,
checkpointLocation: String,
- currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) {
+ currentBatchId: Long)
+ extends QueryExecution(sparkSession, logicalPlan) {
// TODO: make this always part of planning.
- val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil
+ val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy :: Nil
// Modified planner with stateful operations.
override def planner: SparkPlanner =
new SparkPlanner(
- sqlContext.sparkContext,
- sqlContext.conf,
+ sparkSession.sparkContext,
+ sparkSession.sessionState.conf,
stateStrategy)
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 2a1fa1ba62..ea3c73d984 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -44,13 +44,14 @@ import org.apache.spark.util.UninterruptibleThread
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
- override val sqlContext: SQLContext,
+ override val sparkSession: SparkSession,
override val name: String,
checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink,
val outputMode: OutputMode,
- val trigger: Trigger) extends ContinuousQuery with Logging {
+ val trigger: Trigger)
+ extends ContinuousQuery with Logging {
/** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
@@ -108,7 +109,7 @@ class StreamExecution(
* processed and the N-1th entry indicates which offsets have been durably committed to the sink.
*/
private val offsetLog =
- new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets"))
+ new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets"))
/** Whether the query is currently active or not */
override def isActive: Boolean = state == ACTIVE
@@ -158,7 +159,7 @@ class StreamExecution(
startLatch.countDown()
// While active, repeatedly attempt to run batches.
- SQLContext.setActive(sqlContext)
+ SQLContext.setActive(sparkSession.wrapped)
populateStartOffsets()
logDebug(s"Stream running from $committedOffsets to $availableOffsets")
triggerExecutor.execute(() => {
@@ -181,7 +182,7 @@ class StreamExecution(
logError(s"Query $name terminated with error", e)
} finally {
state = TERMINATED
- sqlContext.streams.notifyQueryTermination(StreamExecution.this)
+ sparkSession.streams.notifyQueryTermination(StreamExecution.this)
postEvent(new QueryTerminated(this))
terminationLatch.countDown()
}
@@ -317,7 +318,7 @@ class StreamExecution(
val optimizerStart = System.nanoTime()
lastExecution = new IncrementalExecution(
- sqlContext,
+ sparkSession,
newPlan,
outputMode,
checkpointFile("state"),
@@ -328,7 +329,7 @@ class StreamExecution(
logDebug(s"Optimized batch in ${optimizerTime}ms")
val nextBatch =
- new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema))
+ new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema))
sink.addBatch(currentBatchId - 1, nextBatch)
awaitBatchLock.synchronized {
@@ -344,7 +345,7 @@ class StreamExecution(
}
private def postEvent(event: ContinuousQueryListener.Event) {
- sqlContext.streams.postListenerEvent(event)
+ sparkSession.streams.postListenerEvent(event)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala
index a08a4bb4c3..b2bc31634c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala
@@ -20,17 +20,17 @@ package org.apache.spark.sql.execution.streaming
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{FileCatalog, Partition, PartitionSpec}
import org.apache.spark.sql.types.StructType
-class StreamFileCatalog(sqlContext: SQLContext, path: Path) extends FileCatalog with Logging {
+class StreamFileCatalog(sparkSession: SparkSession, path: Path) extends FileCatalog with Logging {
val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
logInfo(s"Reading streaming file log from $metadataDirectory")
- val metadataLog = new FileStreamSinkLog(sqlContext, metadataDirectory.toUri.toString)
- val fs = path.getFileSystem(sqlContext.sessionState.hadoopConf)
+ val metadataLog = new FileStreamSinkLog(sparkSession, metadataDirectory.toUri.toString)
+ val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf)
override def paths: Seq[Path] = path :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 3820968324..0d2a6dd929 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -58,11 +58,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def schema: StructType = encoder.schema
def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
- Dataset(sqlContext, logicalPlan)
+ Dataset(sqlContext.sparkSession, logicalPlan)
}
def toDF()(implicit sqlContext: SQLContext): DataFrame = {
- Dataset.ofRows(sqlContext, logicalPlan)
+ Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}
def addData(data: A*): Offset = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 1341e45483..4a1f12d685 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
@@ -70,11 +70,11 @@ case class ScalarSubquery(
/**
* Plans scalar subqueries that are present in the given [[SparkPlan]].
*/
-case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
- val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan
+ val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
ScalarSubquery(executedPlan, subquery.exprId)
}
}
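
Subquery planning follows the same substitution: the physical plan for a scalar subquery is now produced from the session. A one-line sketch of the rewritten step, with `subqueryPlan` as a stand-in for `subquery.plan`:

    val executedPlan = new QueryExecution(sparkSession, subqueryPlan).executedPlan
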
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f2448af991..fe63c80815 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -932,7 +932,7 @@ object functions {
* @since 1.5.0
*/
def broadcast(df: DataFrame): DataFrame = {
- Dataset.ofRows(df.sqlContext, BroadcastHint(df.logicalPlan))
+ Dataset.ofRows(df.sparkSession, BroadcastHint(df.logicalPlan))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 04ad729659..1bda572e63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -37,9 +37,9 @@ import org.apache.spark.sql.util.ExecutionListenerManager
/**
- * A class that holds all session-specific state in a given [[SQLContext]].
+ * A class that holds all session-specific state in a given [[SparkSession]].
*/
-private[sql] class SessionState(ctx: SQLContext) {
+private[sql] class SessionState(sparkSession: SparkSession) {
// Note: These are all lazy vals because they depend on each other (e.g. conf) and we
// want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs.
@@ -48,10 +48,12 @@ private[sql] class SessionState(ctx: SQLContext) {
* SQL-specific key-value configurations.
*/
lazy val conf: SQLConf = new SQLConf
- lazy val hadoopConf: Configuration = new Configuration(ctx.sparkContext.hadoopConfiguration)
+ lazy val hadoopConf: Configuration = {
+ new Configuration(sparkSession.sparkContext.hadoopConfiguration)
+ }
// Automatically extract `spark.sql.*` entries and put it in our SQLConf
- setConf(SQLContext.getSQLProperties(ctx.sparkContext.getConf))
+ setConf(SQLContext.getSQLProperties(sparkSession.sparkContext.getConf))
lazy val experimentalMethods = new ExperimentalMethods
@@ -68,7 +70,7 @@ private[sql] class SessionState(ctx: SQLContext) {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(resource.uri)
- case FileResource => ctx.sparkContext.addFile(resource.uri)
+ case FileResource => sparkSession.sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
@@ -82,10 +84,10 @@ private[sql] class SessionState(ctx: SQLContext) {
* Internal catalog for managing table and database states.
*/
lazy val catalog = new SessionCatalog(
- ctx.externalCatalog,
- functionResourceLoader,
- functionRegistry,
- conf)
+ sparkSession.externalCatalog,
+ functionResourceLoader,
+ functionRegistry,
+ conf)
/**
* Interface exposed to the user for registering user-defined functions.
@@ -100,7 +102,7 @@ private[sql] class SessionState(ctx: SQLContext) {
override val extendedResolutionRules =
PreInsertCastAndRename ::
DataSourceAnalysis ::
- (if (conf.runSQLonFile) new ResolveDataSource(ctx) :: Nil else Nil)
+ (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog))
}
@@ -120,7 +122,7 @@ private[sql] class SessionState(ctx: SQLContext) {
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
- new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
+ new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -131,14 +133,16 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s.
*/
- lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx)
+ lazy val continuousQueryManager: ContinuousQueryManager = {
+ new ContinuousQueryManager(sparkSession)
+ }
// ------------------------------------------------------
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
- def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(ctx, plan)
+ def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan)
def refreshTable(tableName: String): Unit = {
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
@@ -162,7 +166,7 @@ private[sql] class SessionState(ctx: SQLContext) {
}
def addJar(path: String): Unit = {
- ctx.sparkContext.addJar(path)
+ sparkSession.sparkContext.addJar(path)
}
/**
@@ -173,7 +177,7 @@ private[sql] class SessionState(ctx: SQLContext) {
* in the external catalog.
*/
def analyze(tableName: String): Unit = {
- AnalyzeTable(tableName).run(ctx)
+ AnalyzeTable(tableName).run(sparkSession)
}
def runNativeSql(sql: String): Seq[String] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 18e04c24a4..47b55e2547 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -57,7 +57,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
rows(0))
// dropna on a dataframe with no columns should return an empty data frame.
- val empty = input.sqlContext.emptyDataFrame.select()
+ val empty = input.sparkSession.emptyDataFrame.select()
assert(empty.na.drop().count() === 0L)
// Make sure the columns are properly named.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4c18784126..681476b6e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -464,8 +464,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("callUDF in SQLContext") {
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
- val sqlctx = df.sqlContext
- sqlctx.udf.register("simpleUDF", (v: Int) => v * v)
+ df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
checkAnswer(
df.select($"id", callUDF("simpleUDF", $"value")),
Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
@@ -618,7 +617,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
test("apply on query results (SPARK-5462)") {
- val df = testData.sqlContext.sql("select key from testData")
+ val df = testData.sparkSession.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
}
@@ -975,7 +974,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
// error case: insert into an OneRowRelation
- Dataset.ofRows(sqlContext, OneRowRelation).registerTempTable("one_row")
+ Dataset.ofRows(sqlContext.sparkSession, OneRowRelation).registerTempTable("one_row")
val e3 = intercept[AnalysisException] {
insertion.write.insertInto("one_row")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d9b374b792..df8b3b7d87 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -282,7 +282,7 @@ abstract class QueryTest extends PlanTest {
def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
case l: LogicalRDD =>
val origin = logicalRDDs.pop()
- LogicalRDD(l.output, origin.rdd)(sqlContext)
+ LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession)
case l: LocalRelation =>
val origin = localRelations.pop()
l.copy(data = origin.data)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index c014f61679..dff6acc94b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.util.Utils
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 38318740a5..073e0b3f00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,7 +231,7 @@ object SparkPlanTest {
}
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
- val execution = new QueryExecution(sqlContext, null) {
+ val execution = new QueryExecution(sqlContext.sparkSession, null) {
override lazy val sparkPlan: SparkPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index fb70dbd961..9da0af3a76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -280,7 +280,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
))
val fakeRDD = new FileScanRDD(
- sqlContext,
+ sqlContext.sparkSession,
(file: PartitionedFile) => Iterator.empty,
Seq(partition)
)
@@ -414,7 +414,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
l.copy(relation =
r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil))))
}
- Dataset.ofRows(sqlContext, bucketed)
+ Dataset.ofRows(sqlContext.sparkSession, bucketed)
} else {
df
}
@@ -449,7 +449,7 @@ class TestFileFormat extends FileFormat {
* Spark will require that the user specify the schema manually.
*/
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] =
Some(
@@ -463,7 +463,7 @@ class TestFileFormat extends FileFormat {
* by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
*/
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -471,7 +471,7 @@ class TestFileFormat extends FileFormat {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 1a7b62ca0a..e5588bec4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1316,7 +1316,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val d1 = DataSource(
- sqlContext,
+ sqlContext.sparkSession,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
bucketSpec = None,
@@ -1324,7 +1324,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
options = Map("path" -> path)).resolveRelation()
val d2 = DataSource(
- sqlContext,
+ sqlContext.sparkSession,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
bucketSpec = None,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
index a164f4c733..df127d958e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
@@ -263,7 +263,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = {
withTempDir { file =>
- val sinkLog = new FileStreamSinkLog(sqlContext, file.getCanonicalPath)
+ val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath)
f(sinkLog)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
index 22e011cfb7..129b5a8c36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
@@ -59,7 +59,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
test("HDFSMetadataLog: basic") {
withTempDir { temp =>
val dir = new File(temp, "dir") // use a non-existent directory to test whether the log makes the dir
- val metadataLog = new HDFSMetadataLog[String](sqlContext, dir.getAbsolutePath)
+ val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath)
assert(metadataLog.add(0, "batch0"))
assert(metadataLog.getLatest() === Some(0 -> "batch0"))
assert(metadataLog.get(0) === Some("batch0"))
@@ -86,14 +86,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
s"fs.$scheme.impl",
classOf[FakeFileSystem].getName)
withTempDir { temp =>
- val metadataLog = new HDFSMetadataLog[String](sqlContext, s"$scheme://$temp")
+ val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp")
assert(metadataLog.add(0, "batch0"))
assert(metadataLog.getLatest() === Some(0 -> "batch0"))
assert(metadataLog.get(0) === Some("batch0"))
assert(metadataLog.get(None, 0) === Array(0 -> "batch0"))
- val metadataLog2 = new HDFSMetadataLog[String](sqlContext, s"$scheme://$temp")
+ val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp")
assert(metadataLog2.get(0) === Some("batch0"))
assert(metadataLog2.getLatest() === Some(0 -> "batch0"))
assert(metadataLog2.get(None, 0) === Array(0 -> "batch0"))
@@ -103,7 +103,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
test("HDFSMetadataLog: restart") {
withTempDir { temp =>
- val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath)
+ val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
assert(metadataLog.add(0, "batch0"))
assert(metadataLog.add(1, "batch1"))
assert(metadataLog.get(0) === Some("batch0"))
@@ -111,7 +111,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
assert(metadataLog.getLatest() === Some(1 -> "batch1"))
assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1"))
- val metadataLog2 = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath)
+ val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
assert(metadataLog2.get(0) === Some("batch0"))
assert(metadataLog2.get(1) === Some("batch1"))
assert(metadataLog2.getLatest() === Some(1 -> "batch1"))
@@ -126,7 +126,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
for (id <- 0 until 10) {
new Thread() {
override def run(): Unit = waiter {
- val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath)
+ val metadataLog =
+ new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
try {
var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)
nextBatchId += 1
@@ -145,7 +146,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
}
waiter.await(timeout(10.seconds), dismissals(10))
- val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath)
+ val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath)
assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString))
assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString)))
}
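
The suite exercises the new `HDFSMetadataLog(sparkSession, path)` constructor throughout; a condensed sketch of the basic round trip it covers (`spark` and the path are placeholders):

    val log = new HDFSMetadataLog[String](spark, "/tmp/metadata")
    assert(log.add(0, "batch0"))
    assert(log.getLatest() === Some(0 -> "batch0"))
    assert(log.get(0) === Some("batch0"))
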
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 5f8514e1a2..612cfc7ec7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -28,13 +28,21 @@ class DDLScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
- SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext)
+ SimpleDDLScan(
+ parameters("from").toInt,
+ parameters("TO").toInt,
+ parameters("Table"))(sqlContext.sparkSession)
}
}
-case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext)
+case class SimpleDDLScan(
+ from: Int,
+ to: Int,
+ table: String)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
override def schema: StructType =
StructType(Seq(
StructField("intType", IntegerType, nullable = false,
@@ -63,7 +71,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
override def buildScan(): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
- sqlContext.sparkContext.parallelize(from to to).map { e =>
+ sparkSession.sparkContext.parallelize(from to to).map { e =>
InternalRow(UTF8String.fromString(s"people$e"), e * 2)
}.asInstanceOf[RDD[Row]]
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 14707774cf..51d04f2f4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -32,14 +32,16 @@ class FilteredScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
- SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext.sparkSession)
}
}
-case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
override def schema: StructType =
StructType(
StructField("a", IntegerType, nullable = false) ::
@@ -115,7 +117,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c))
}
- sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
+ sparkSession.sparkContext.parallelize(from to to).filter(eval).map(i =>
Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 9bb901bfb3..cd0256db43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -29,14 +29,16 @@ class PrunedScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
- SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext.sparkSession)
}
}
-case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedScan {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
override def schema: StructType =
StructType(
StructField("a", IntegerType, nullable = false) ::
@@ -48,7 +50,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
case "b" => (i: Int) => Seq(i * 2)
}
- sqlContext.sparkContext.parallelize(from to to).map(i =>
+ sparkSession.sparkContext.parallelize(from to to).map(i =>
Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
}
}
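
The DDL, filtered, pruned, and table scan suites all follow the same recipe: the test relation captures a `SparkSession` and satisfies the public `BaseRelation.sqlContext` contract through `wrapped`. A minimal sketch of the pattern, assuming the suites' existing imports (`SparkSession`, `SQLContext`, `BaseRelation`, `TableScan`, `RDD`, `Row`, `StructType`, `StructField`, `IntegerType`); the class name is illustrative:

    case class ExampleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
      extends BaseRelation with TableScan {

      // BaseRelation still exposes a SQLContext publicly; delegate to the wrapped one.
      override def sqlContext: SQLContext = sparkSession.wrapped

      override def schema: StructType =
        StructType(StructField("i", IntegerType, nullable = false) :: Nil)

      // Build the scan from the session's SparkContext, as the migrated relations do.
      override def buildScan(): RDD[Row] =
        sparkSession.sparkContext.parallelize(from to to).map(Row(_))
    }
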
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 94d032f4ee..4f6df54417 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
class ResolvedDataSourceSuite extends SparkFunSuite {
private def getProvidingClass(name: String): Class[_] =
- DataSource(sqlContext = null, className = name).providingClass
+ DataSource(sparkSession = null, className = name).providingClass
test("jdbc") {
assert(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 99f1661ad0..34b8726a92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -31,17 +31,21 @@ class SimpleScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
- SimpleScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext)
+ SimpleScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext.sparkSession)
}
}
-case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
override def schema: StructType =
StructType(StructField("i", IntegerType, nullable = false) :: Nil)
- override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
+ override def buildScan(): RDD[Row] = {
+ sparkSession.sparkContext.parallelize(from to to).map(Row(_))
+ }
}
class AllDataTypesScanSource extends SchemaRelationProvider {
@@ -53,23 +57,27 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
parameters("option_with_underscores")
parameters("option.with.dots")
- AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
+ AllDataTypesScan(
+ parameters("from").toInt,
+ parameters("TO").toInt, schema)(sqlContext.sparkSession)
}
}
case class AllDataTypesScan(
from: Int,
to: Int,
- userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
+ userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession)
extends BaseRelation
with TableScan {
+ override def sqlContext: SQLContext = sparkSession.wrapped
+
override def schema: StructType = userSpecifiedSchema
override def needConversion: Boolean = true
override def buildScan(): RDD[Row] = {
- sqlContext.sparkContext.parallelize(from to to).map { i =>
+ sparkSession.sparkContext.parallelize(from to to).map { i =>
Row(
s"str_$i",
s"str_$i".getBytes(StandardCharsets.UTF_8),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index fcfac359f3..5577c9f3ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -257,7 +257,7 @@ private[sql] trait SQLTestUtils
* way to construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
- Dataset.ofRows(sqlContext, plan)
+ Dataset.ofRows(sqlContext.sparkSession, plan)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index d270775af6..9799c6d42b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.internal.{RuntimeConfigImpl, SessionState, SQLConf}
* A special [[SQLContext]] prepared for testing.
*/
private[sql] class TestSQLContext(
- @transient private val sparkSession: SparkSession,
+ @transient override val sparkSession: SparkSession,
isRootContext: Boolean)
extends SQLContext(sparkSession, isRootContext) { self =>
@@ -57,7 +57,7 @@ private[sql] class TestSQLContext(
private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
@transient
- protected[sql] override lazy val sessionState: SessionState = new SessionState(wrapped) {
+ protected[sql] override lazy val sessionState: SessionState = new SessionState(self) {
override lazy val conf: SQLConf = {
new SQLConf {
clear()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index edb87b94ea..01b7cfbd2e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -24,7 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
+import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
@@ -45,16 +45,15 @@ import org.apache.spark.sql.types._
* This is still used for things like creating data source tables, but in the future will be
* cleaned up to integrate more nicely with [[HiveExternalCatalog]].
*/
-private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
- private val conf = hive.conf
- private val sessionState = hive.sessionState.asInstanceOf[HiveSessionState]
- private val client = hive.sharedState.asInstanceOf[HiveSharedState].metadataHive
- private val hiveconf = sessionState.hiveconf
+private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
+ private val conf = sparkSession.conf
+ private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
+ private val client = sparkSession.sharedState.asInstanceOf[HiveSharedState].metadataHive
/** A fully qualified identifier for a table (i.e., database.tableName) */
case class QualifiedTableName(database: String, name: String)
- private def getCurrentDatabase: String = hive.sessionState.catalog.getCurrentDatabase
+ private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase
def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = {
QualifiedTableName(
@@ -124,7 +123,7 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
val options = table.storage.serdeProperties
val dataSource =
DataSource(
- hive,
+ sparkSession,
userSpecifiedSchema = userSpecifiedSchema,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
@@ -179,12 +178,12 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
alias match {
// because hive uses things like `_c0` to build the expanded text
// currently we cannot support view from "create view v1(c1) as ..."
- case None => SubqueryAlias(table.identifier.table, hive.parseSql(viewText))
- case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText))
+ case None => SubqueryAlias(table.identifier.table, sparkSession.parseSql(viewText))
+ case Some(aliasText) => SubqueryAlias(aliasText, sparkSession.parseSql(viewText))
}
} else {
MetastoreRelation(
- qualifiedTableName.database, qualifiedTableName.name, alias)(table, client, hive)
+ qualifiedTableName.database, qualifiedTableName.name, alias)(table, client, sparkSession)
}
}
@@ -275,19 +274,20 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
val hadoopFsRelation = cached.getOrElse {
val paths = new Path(metastoreRelation.catalogTable.storage.locationUri.get) :: Nil
- val fileCatalog = new MetaStoreFileCatalog(hive, paths, partitionSpec)
+ val fileCatalog = new MetaStoreFileCatalog(sparkSession, paths, partitionSpec)
val inferredSchema = if (fileType.equals("parquet")) {
- val inferredSchema = defaultSource.inferSchema(hive, options, fileCatalog.allFiles())
+ val inferredSchema =
+ defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles())
inferredSchema.map { inferred =>
ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred)
}.getOrElse(metastoreSchema)
} else {
- defaultSource.inferSchema(hive, options, fileCatalog.allFiles()).get
+ defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get
}
val relation = HadoopFsRelation(
- sqlContext = hive,
+ sparkSession = sparkSession,
location = fileCatalog,
partitionSchema = partitionSchema,
dataSchema = inferredSchema,
@@ -314,7 +314,7 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
val created =
LogicalRelation(
DataSource(
- sqlContext = hive,
+ sparkSession = sparkSession,
paths = paths,
userSpecifiedSchema = Some(metastoreRelation.schema),
bucketSpec = bucketSpec,
@@ -436,7 +436,8 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
case p: LogicalPlan if !p.childrenResolved => p
case p: LogicalPlan if p.resolved => p
- case CreateViewCommand(table, child, allowExisting, replace, sql) if !conf.nativeView =>
+ case CreateViewCommand(table, child, allowExisting, replace, sql)
+ if !sessionState.conf.nativeView =>
HiveNativeCommand(sql)
case p @ CreateTableAsSelectLogicalPlan(table, child, allowExisting) =>
@@ -462,7 +463,7 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists
CreateTableUsingAsSelect(
TableIdentifier(desc.identifier.table),
- conf.defaultDataSourceName,
+ sessionState.conf.defaultDataSourceName,
temporary = false,
Array.empty[String],
bucketSpec = None,
@@ -538,13 +539,17 @@ private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging {
* the information from the metastore.
*/
private[hive] class MetaStoreFileCatalog(
- ctx: SQLContext,
+ sparkSession: SparkSession,
paths: Seq[Path],
partitionSpecFromHive: PartitionSpec)
- extends HDFSFileCatalog(ctx, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) {
+ extends HDFSFileCatalog(
+ sparkSession,
+ Map.empty,
+ paths,
+ Some(partitionSpecFromHive.partitionColumns)) {
override def getStatus(path: Path): Array[FileStatus] = {
- val fs = path.getFileSystem(ctx.sessionState.hadoopConf)
+ val fs = path.getFileSystem(sparkSession.sessionState.hadoopConf)
fs.listStatus(path)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 9e527073d4..f70131ec86 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF}
import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry}
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
-import org.apache.spark.sql.{AnalysisException, SQLContext}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -43,7 +43,7 @@ import org.apache.spark.util.Utils
private[sql] class HiveSessionCatalog(
externalCatalog: HiveExternalCatalog,
client: HiveClient,
- context: SQLContext,
+ sparkSession: SparkSession,
functionResourceLoader: FunctionResourceLoader,
functionRegistry: FunctionRegistry,
conf: SQLConf,
@@ -82,7 +82,7 @@ private[sql] class HiveSessionCatalog(
// essentially a cache for metastore tables. However, it relies on a lot of session-specific
// things so it would be a lot of work to split its functionality between HiveSessionCatalog
// and HiveCatalog. We should still do it at some point...
- private val metastoreCatalog = new HiveMetastoreCatalog(context)
+ private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index bf0288c9f7..4a8978e553 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -33,11 +33,14 @@ import org.apache.spark.sql.internal.SessionState
/**
* A class that holds all session-specific state in a given [[SparkSession]] backed by Hive.
*/
-private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) {
+private[hive] class HiveSessionState(sparkSession: SparkSession)
+ extends SessionState(sparkSession) {
self =>
- private lazy val sharedState: HiveSharedState = ctx.sharedState.asInstanceOf[HiveSharedState]
+ private lazy val sharedState: HiveSharedState = {
+ sparkSession.sharedState.asInstanceOf[HiveSharedState]
+ }
/**
* A Hive client used for execution.
@@ -72,8 +75,8 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx)
new HiveSessionCatalog(
sharedState.externalCatalog,
metadataHive,
- ctx,
- ctx.sessionState.functionResourceLoader,
+ sparkSession,
+ functionResourceLoader,
functionRegistry,
conf,
hiveconf)
@@ -91,7 +94,7 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx)
catalog.PreInsertionCasts ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
- (if (conf.runSQLonFile) new ResolveDataSource(ctx) :: Nil else Nil)
+ (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog))
}
@@ -101,9 +104,9 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx)
* Planner that takes into account Hive-specific strategies.
*/
override def planner: SparkPlanner = {
- new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
+ new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
with HiveStrategies {
- override val context: SQLContext = ctx
+ override val sparkSession: SparkSession = self.sparkSession
override val hiveconf: HiveConf = self.hiveconf
override def strategies: Seq[Strategy] = {
@@ -225,7 +228,7 @@ private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx)
// TODO: why do we get this from SparkConf but not SQLConf?
def hiveThriftServerSingleSession: Boolean = {
- ctx.sparkContext.conf.getBoolean(
+ sparkSession.sparkContext.conf.getBoolean(
"spark.sql.hive.thriftServer.singleSession", defaultValue = false)
}
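
As the hiveThriftServerSingleSession hunk above shows, settings that live in SparkConf rather than SQLConf are now read through the session's SparkContext. A minimal sketch of that lookup, using only the public SparkSession API (not part of the patch itself):

    import org.apache.spark.sql.SparkSession

    // Illustrative sketch only: read a SparkConf-backed flag through the
    // session's SparkContext, defaulting to false when the key is unset.
    def singleSessionEnabled(sparkSession: SparkSession): Boolean = {
      sparkSession.sparkContext.conf.getBoolean(
        "spark.sql.hive.thriftServer.singleSession", defaultValue = false)
    }
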
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 2bea32b144..7d1f87f390 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -33,7 +33,7 @@ private[hive] trait HiveStrategies {
// Possibly being too clever with types here... or not clever enough.
self: SparkPlanner =>
- val context: SQLContext
+ val sparkSession: SparkSession
val hiveconf: HiveConf
object Scripts extends Strategy {
@@ -78,7 +78,7 @@ private[hive] trait HiveStrategies {
projectList,
otherPredicates,
identity[Seq[Expression]],
- HiveTableScanExec(_, relation, pruningPredicates)(context, hiveconf)) :: Nil
+ HiveTableScanExec(_, relation, pruningPredicates)(sparkSession, hiveconf)) :: Nil
case _ =>
Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala
index cd45706841..0520e75306 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable}
import org.apache.hadoop.hive.ql.plan.TableDesc
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Expression}
@@ -42,7 +42,7 @@ private[hive] case class MetastoreRelation(
alias: Option[String])
(val catalogTable: CatalogTable,
@transient private val client: HiveClient,
- @transient private val sqlContext: SQLContext)
+ @transient private val sparkSession: SparkSession)
extends LeafNode with MultiInstanceRelation with FileRelation with CatalogRelation {
override def equals(other: Any): Boolean = other match {
@@ -58,7 +58,7 @@ private[hive] case class MetastoreRelation(
Objects.hashCode(databaseName, tableName, alias, output)
}
- override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sqlContext :: Nil
+ override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil
private def toHiveColumn(c: CatalogColumn): FieldSchema = {
new FieldSchema(c.name, c.dataType, c.comment.orNull)
@@ -124,7 +124,7 @@ private[hive] case class MetastoreRelation(
// if the size is still less than zero, we use default size
Option(totalSize).map(_.toLong).filter(_ > 0)
.getOrElse(Option(rawDataSize).map(_.toLong).filter(_ > 0)
- .getOrElse(sqlContext.conf.defaultSizeInBytes)))
+ .getOrElse(sparkSession.sessionState.conf.defaultSizeInBytes)))
}
)
@@ -133,7 +133,7 @@ private[hive] case class MetastoreRelation(
private lazy val allPartitions: Seq[CatalogTablePartition] = client.getAllPartitions(catalogTable)
def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
- val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) {
+ val rawPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) {
client.getPartitionsByFilter(catalogTable, predicates)
} else {
allPartitions
@@ -226,6 +226,6 @@ private[hive] case class MetastoreRelation(
}
override def newInstance(): MetastoreRelation = {
- MetastoreRelation(databaseName, tableName, alias)(catalogTable, client, sqlContext)
+ MetastoreRelation(databaseName, tableName, alias)(catalogTable, client, sparkSession)
}
}
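
The statistics hunk above keeps the same fallback chain for table size and only changes where the default comes from (sparkSession.sessionState.conf). A standalone sketch of that logic, with the session default passed in explicitly and parameter names chosen for illustration (not part of the patch itself):

    // Illustrative sketch only: prefer a positive totalSize, then a positive
    // rawDataSize, and fall back to the session's defaultSizeInBytes. The
    // metastore supplies the raw string values.
    def estimatedSizeInBytes(
        totalSize: String,
        rawDataSize: String,
        defaultSizeInBytes: Long): Long = {
      Option(totalSize).map(_.toLong).filter(_ > 0)
        .getOrElse(Option(rawDataSize).map(_.toLong).filter(_ > 0)
          .getOrElse(defaultSizeInBytes))
    }
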
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index e95069e830..af0317f7a1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -36,7 +36,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -61,7 +61,7 @@ private[hive]
class HadoopTableReader(
@transient private val attributes: Seq[Attribute],
@transient private val relation: MetastoreRelation,
- @transient private val sc: SQLContext,
+ @transient private val sparkSession: SparkSession,
hiveconf: HiveConf)
extends TableReader with Logging {
@@ -69,15 +69,15 @@ class HadoopTableReader(
// https://hadoop.apache.org/docs/r1.0.4/mapred-default.html
//
// In order keep consistency with Hive, we will let it be 0 in local mode also.
- private val _minSplitsPerRDD = if (sc.sparkContext.isLocal) {
+ private val _minSplitsPerRDD = if (sparkSession.sparkContext.isLocal) {
0 // will splitted based on block by default.
} else {
- math.max(hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions)
+ math.max(hiveconf.getInt("mapred.map.tasks", 1), sparkSession.sparkContext.defaultMinPartitions)
}
- SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveconf)
+ SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sparkSession.sparkContext.conf, hiveconf)
private val _broadcastedHiveConf =
- sc.sparkContext.broadcast(new SerializableConfiguration(hiveconf))
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hiveconf))
override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
makeRDDForTable(
@@ -153,7 +153,7 @@ class HadoopTableReader(
def verifyPartitionPath(
partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]):
Map[HivePartition, Class[_ <: Deserializer]] = {
- if (!sc.conf.verifyPartitionPath) {
+ if (!sparkSession.sessionState.conf.verifyPartitionPath) {
partitionToDeserializer
} else {
var existPathSet = collection.mutable.Set[String]()
@@ -246,7 +246,7 @@ class HadoopTableReader(
// Even if we don't use any partitions, we still need an empty RDD
if (hivePartitionRDDs.size == 0) {
- new EmptyRDD[InternalRow](sc.sparkContext)
+ new EmptyRDD[InternalRow](sparkSession.sparkContext)
} else {
new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
}
@@ -278,7 +278,7 @@ class HadoopTableReader(
val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _
val rdd = new HadoopRDD(
- sc.sparkContext,
+ sparkSession.sparkContext,
_broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]],
Some(initializeJobConfFunc),
inputFormatClass,
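
HadoopTableReader now reaches SparkContext through the SparkSession it is handed, both for computing split counts and for broadcasting the Hive configuration. A minimal sketch of the broadcast step, assuming the Spark-internal SerializableConfiguration wrapper shown above (not part of the patch itself):

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.util.SerializableConfiguration

    // Illustrative sketch only: wrap a Hadoop Configuration (a HiveConf works
    // too, since it extends Configuration) so it can be shipped to executors.
    // SerializableConfiguration is internal to Spark.
    def broadcastHadoopConf(
        sparkSession: SparkSession,
        conf: Configuration): Broadcast[SerializableConfiguration] = {
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf))
    }
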
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 9240f9c7d2..08d4b99d30 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable}
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
@@ -42,7 +42,7 @@ case class CreateTableAsSelect(
override def children: Seq[LogicalPlan] = Seq(query)
- override def run(sqlContext: SQLContext): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
lazy val metastoreRelation: MetastoreRelation = {
import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
@@ -68,24 +68,24 @@ case class CreateTableAsSelect(
withFormat
}
- sqlContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false)
+ sparkSession.sessionState.catalog.createTable(withSchema, ignoreIfExists = false)
// Get the Metastore Relation
- sqlContext.sessionState.catalog.lookupRelation(tableIdentifier) match {
+ sparkSession.sessionState.catalog.lookupRelation(tableIdentifier) match {
case r: MetastoreRelation => r
}
}
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
// processing.
- if (sqlContext.sessionState.catalog.tableExists(tableIdentifier)) {
+ if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) {
if (allowExisting) {
// table already exists, will do nothing, to keep consistent with Hive
} else {
throw new AnalysisException(s"$tableIdentifier already exists.")
}
} else {
- sqlContext.executePlan(InsertIntoTable(
+ sparkSession.executePlan(InsertIntoTable(
metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd
}
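
The run method above keeps its shape: register the table through the session catalog, then either honor allowExisting or execute the insert. A compact sketch of that control flow against SparkSession, with the insert abstracted into a caller-supplied thunk; the helper and its shape are hypothetical, not part of the patch:

    import org.apache.spark.sql.{AnalysisException, SparkSession}
    import org.apache.spark.sql.catalyst.TableIdentifier

    // Illustrative sketch only: the existence check performed by run() above.
    def createIfAbsent(
        sparkSession: SparkSession,
        tableIdentifier: TableIdentifier,
        allowExisting: Boolean)(doInsert: => Unit): Unit = {
      if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) {
        if (!allowExisting) {
          throw new AnalysisException(s"$tableIdentifier already exists.")
        }
        // table already exists and allowExisting is set: do nothing, like Hive
      } else {
        doInsert
      }
    }
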
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 0f72091096..cc5bbf59db 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution._
@@ -48,8 +48,8 @@ case class HiveTableScanExec(
requestedAttributes: Seq[Attribute],
relation: MetastoreRelation,
partitionPruningPred: Seq[Expression])(
- @transient val context: SQLContext,
- @transient val hiveconf: HiveConf)
+ @transient private val sparkSession: SparkSession,
+ @transient private val hiveconf: HiveConf)
extends LeafExecNode {
require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
@@ -84,7 +84,7 @@ case class HiveTableScanExec(
@transient
private[this] val hadoopReader =
- new HadoopTableReader(attributes, relation, context, hiveExtraConf)
+ new HadoopTableReader(attributes, relation, sparkSession, hiveExtraConf)
private[this] def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 1095f5fd95..cb49fc910b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{HadoopRDD, RDD}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -52,17 +52,17 @@ private[sql] class DefaultSource
override def toString: String = "ORC"
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
OrcFileOperator.readSchema(
files.map(_.getPath.toUri.toString),
- Some(new Configuration(sqlContext.sessionState.hadoopConf))
+ Some(new Configuration(sparkSession.sessionState.hadoopConf))
)
}
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -109,15 +109,15 @@ private[sql] class DefaultSource
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
- val orcConf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val orcConf = new Configuration(sparkSession.sessionState.hadoopConf)
- if (sqlContext.conf.orcFilterPushDown) {
+ if (sparkSession.sessionState.conf.orcFilterPushDown) {
// Sets pushed predicates
OrcFilters.createFilter(filters.toArray).foreach { f =>
orcConf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
@@ -125,7 +125,8 @@ private[sql] class DefaultSource
}
}
- val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(orcConf))
+ val broadcastedConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(orcConf))
(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
@@ -270,7 +271,7 @@ private[orc] class OrcOutputWriter(
}
private[orc] case class OrcTableScan(
- @transient sqlContext: SQLContext,
+ @transient sparkSession: SparkSession,
attributes: Seq[Attribute],
filters: Array[Filter],
@transient inputPaths: Seq[FileStatus])
@@ -278,11 +279,11 @@ private[orc] case class OrcTableScan(
with HiveInspectors {
def execute(): RDD[InternalRow] = {
- val job = Job.getInstance(new Configuration(sqlContext.sessionState.hadoopConf))
+ val job = Job.getInstance(new Configuration(sparkSession.sessionState.hadoopConf))
val conf = job.getConfiguration
// Tries to push down filters if ORC filter push-down is enabled
- if (sqlContext.conf.orcFilterPushDown) {
+ if (sparkSession.sessionState.conf.orcFilterPushDown) {
OrcFilters.createFilter(filters).foreach { f =>
conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
@@ -294,14 +295,14 @@ private[orc] case class OrcTableScan(
val orcFormat = new DefaultSource
val dataSchema =
orcFormat
- .inferSchema(sqlContext, Map.empty, inputPaths)
+ .inferSchema(sparkSession, Map.empty, inputPaths)
.getOrElse(sys.error("Failed to read schema from target ORC files."))
// Sets requested columns
OrcRelation.setRequiredColumns(conf, dataSchema, StructType.fromAttributes(attributes))
if (inputPaths.isEmpty) {
// the input path probably be pruned, return an empty RDD.
- return sqlContext.sparkContext.emptyRDD[InternalRow]
+ return sparkSession.sparkContext.emptyRDD[InternalRow]
}
FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*)
@@ -309,7 +310,7 @@ private[orc] case class OrcTableScan(
classOf[OrcInputFormat]
.asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]]
- val rdd = sqlContext.sparkContext.hadoopRDD(
+ val rdd = sparkSession.sparkContext.hadoopRDD(
conf.asInstanceOf[JobConf],
inputFormatClass,
classOf[NullWritable],
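
Both DefaultSource and OrcTableScan above follow the same two substitutions: SQLConf flags such as orcFilterPushDown move behind sparkSession.sessionState.conf, and fresh Hadoop configurations are cloned from sparkSession.sessionState.hadoopConf. A minimal sketch of the second pattern, assuming the internal sessionState API used in this file (not part of the patch itself):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.mapreduce.Job
    import org.apache.spark.sql.SparkSession

    // Illustrative sketch only: clone the session's Hadoop configuration so a
    // per-scan Job can be configured without mutating shared state.
    def newScanJob(sparkSession: SparkSession): Job = {
      Job.getInstance(new Configuration(sparkSession.sessionState.hadoopConf))
    }
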
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 04b2494043..f74e5cd6f5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -71,7 +71,9 @@ object TestHive
* hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of
* test cases that rely on TestHive must be serialized.
*/
-class TestHiveContext(@transient val sparkSession: TestHiveSparkSession, isRootContext: Boolean)
+class TestHiveContext(
+ @transient override val sparkSession: TestHiveSparkSession,
+ isRootContext: Boolean)
extends SQLContext(sparkSession, isRootContext) {
def this(sc: SparkContext) {
@@ -106,11 +108,11 @@ class TestHiveContext(@transient val sparkSession: TestHiveSparkSession, isRootC
private[hive] class TestHiveSparkSession(
- sc: SparkContext,
+ @transient private val sc: SparkContext,
val warehousePath: File,
scratchDirPath: File,
metastoreTemporaryConf: Map[String, String],
- existingSharedState: Option[TestHiveSharedState])
+ @transient private val existingSharedState: Option[TestHiveSharedState])
extends SparkSession(sc) with Logging { self =>
def this(sc: SparkContext) {
@@ -463,7 +465,7 @@ private[hive] class TestHiveSparkSession(
private[hive] class TestHiveQueryExecution(
sparkSession: TestHiveSparkSession,
logicalPlan: LogicalPlan)
- extends QueryExecution(new SQLContext(sparkSession), logicalPlan) with Logging {
+ extends QueryExecution(sparkSession, logicalPlan) with Logging {
def this(sparkSession: TestHiveSparkSession, sql: String) {
this(sparkSession, sparkSession.sessionState.sqlParser.parsePlan(sql))
@@ -525,7 +527,7 @@ private[hive] class TestHiveSharedState(
private[hive] class TestHiveSessionState(sparkSession: TestHiveSparkSession)
- extends HiveSessionState(new SQLContext(sparkSession)) {
+ extends HiveSessionState(sparkSession) {
override lazy val conf: SQLConf = {
new SQLConf {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
index b121600dae..27c9e992de 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
@@ -64,7 +64,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
""".stripMargin)
}
- checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext, plan))
+ checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext.sparkSession, plan))
}
protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 5965cdc81c..7cd01c9104 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -701,7 +701,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
// Manually create a metastore data source table.
CreateDataSourceTableUtils.createDataSourceTable(
- sqlContext = sqlContext,
+ sparkSession = sqlContext.sparkSession,
tableIdent = TableIdentifier("wide_schema"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
@@ -910,7 +910,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType)))
CreateDataSourceTableUtils.createDataSourceTable(
- sqlContext = sqlContext,
+ sparkSession = sqlContext.sparkSession,
tableIdent = TableIdentifier("not_skip_hive_metadata"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
@@ -925,7 +925,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
.forall(column => DataTypeParser.parse(column.dataType) == StringType))
CreateDataSourceTableUtils.createDataSourceTable(
- sqlContext = sqlContext,
+ sparkSession = sqlContext.sparkSession,
tableIdent = TableIdentifier("skip_hive_metadata"),
userSpecifiedSchema = Some(schema),
partitionColumns = Array.empty[String],
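
The test changes above also show how call sites that only have a SQLContext in scope reach the new signatures: through sqlContext.sparkSession. A one-line sketch of that accessor in use (not part of the patch itself):

    import org.apache.spark.sql.{SparkSession, SQLContext}

    // Illustrative sketch only: tests that still hold a SQLContext can pass
    // its underlying SparkSession to APIs that now require one.
    def sessionOf(sqlContext: SQLContext): SparkSession = sqlContext.sparkSession
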
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index bc87d3ef38..b16c9c133b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -975,7 +975,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
// todo: remove it?
- val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan)
+ val newActual = Dataset.ofRows(sqlContext.sparkSession, actual.logicalPlan)
QueryTest.checkAnswer(newActual, expectedAnswer) match {
case Some(errorMessage) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
index 4a2d190353..5a8a7f0ab5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.TaskContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType
@@ -33,7 +33,7 @@ class CommitFailureTestSource extends SimpleTextSource {
* by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
*/
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory =
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index eced8ed57f..e4bd1f93c5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
-import org.apache.spark.sql.{sources, Row, SQLContext}
+import org.apache.spark.sql.{sources, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -37,14 +37,14 @@ class SimpleTextSource extends FileFormat with DataSourceRegister {
override def shortName(): String = "test"
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
Some(DataType.fromJson(options("dataSchema")).asInstanceOf[StructType])
}
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = new OutputWriterFactory {
@@ -58,7 +58,7 @@ class SimpleTextSource extends FileFormat with DataSourceRegister {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
@@ -74,9 +74,9 @@ class SimpleTextSource extends FileFormat with DataSourceRegister {
inputAttributes.find(_.name == field.name)
}
- val conf = new Configuration(sqlContext.sessionState.hadoopConf)
+ val conf = new Configuration(sparkSession.sessionState.hadoopConf)
val broadcastedConf =
- sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(conf))
(file: PartitionedFile) => {
val predicate = {