From 18c2c92580bdc27aa5129d9e7abda418a3633ea6 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Mon, 25 Apr 2016 20:54:31 -0700
Subject: [SPARK-14861][SQL] Replace internal usages of SQLContext with SparkSession

## What changes were proposed in this pull request?

In Spark 2.0, `SparkSession` is the new entry point. Internally we should stop using `SQLContext` everywhere, since it is no longer meant to be the main user-facing API. In this patch I took care not to break any public APIs. The one suspect place is `o.a.s.ml.source.libsvm.DefaultSource`, but according to mengxr it is not supposed to be public, so it is OK to change the underlying `FileFormat` trait.

**Reviewers**: This is a big patch that may be difficult to review, but the changes are actually quite straightforward. If you prefer, I can break it up into a few smaller patches, but that would slightly delay progress on this issue.

## How was this patch tested?

No change in functionality is intended.

Author: Andrew Or

Closes #12625 from andrewor14/spark-session-refactor.
---
 .../spark/ml/classification/GBTClassifier.scala    |  2 +-
 .../ml/classification/LogisticRegression.scala     |  4 +-
 .../ml/classification/RandomForestClassifier.scala |  2 +-
 .../scala/org/apache/spark/ml/clustering/LDA.scala | 46 +++++++++++-----------
 .../apache/spark/ml/feature/CountVectorizer.scala  |  2 +-
 .../apache/spark/ml/feature/SQLTransformer.scala   |  2 +-
 .../org/apache/spark/ml/feature/Word2Vec.scala     |  2 +-
 .../org/apache/spark/ml/recommendation/ALS.scala   |  2 +-
 .../apache/spark/ml/regression/GBTRegressor.scala  |  2 +-
 .../ml/regression/RandomForestRegressor.scala      |  2 +-
 .../spark/ml/source/libsvm/LibSVMRelation.scala    | 22 +++++------
 .../apache/spark/ml/tuning/CrossValidator.scala    |  6 +--
 .../spark/ml/tuning/TrainValidationSplit.scala     |  1 -
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  8 ++--
 14 files changed, 52 insertions(+), 51 deletions(-)
 (limited to 'mllib/src/main')

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 39a698af15..1bd6dae7dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -199,7 +199,7 @@ final class GBTClassificationModel private[ml](
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c2b440059b..ec7fc977b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -840,8 +840,8 @@ class BinaryLogisticRegressionSummary private[classification] (
     @Since("1.6.0") override val featuresCol: String)
   extends LogisticRegressionSummary {
 
-  private val sqlContext = predictions.sqlContext
-  import sqlContext.implicits._
+  private val sparkSession = predictions.sparkSession
+  import sparkSession.implicits._
 
   /**
    * Returns a BinaryClassificationMetrics object.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index dfa711b243..c04ecc88ae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -181,7 +181,7 @@ final class RandomForestClassificationModel private[ml] (
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index c57ceba4a9..27c813dd61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
 import org.apache.spark.mllib.impl.PeriodicCheckpointer
 import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
 import org.apache.spark.sql.types.StructType
 
@@ -356,14 +356,14 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
  * Model fitted by [[LDA]].
  *
  * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary)
- * @param sqlContext Used to construct local DataFrames for returning query results
+ * @param sparkSession Used to construct local DataFrames for returning query results
  */
 @Since("1.6.0")
 @Experimental
-sealed abstract class LDAModel private[ml] (
+sealed abstract class LDAModel protected[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
-    @Since("1.6.0") @transient protected val sqlContext: SQLContext)
+    @Since("1.6.0") @transient protected[ml] val sparkSession: SparkSession)
   extends Model[LDAModel] with LDAParams with Logging with MLWritable {
 
   // NOTE to developers:
@@ -405,7 +405,7 @@ sealed abstract class LDAModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     if ($(topicDistributionCol).nonEmpty) {
-      val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
+      val t = udf(oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext))
       dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
     } else {
       logWarning("LDAModel.transform was called without any output columns. Set an output column" +
@@ -495,7 +495,7 @@ sealed abstract class LDAModel private[ml] (
       case ((termIndices, termWeights), topic) =>
         (topic, termIndices.toSeq, termWeights.toSeq)
     }
-    sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
+    sparkSession.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
   }
 
   @Since("1.6.0")
@@ -512,16 +512,16 @@ sealed abstract class LDAModel private[ml] (
  */
 @Since("1.6.0")
 @Experimental
-class LocalLDAModel private[ml] (
+class LocalLDAModel protected[ml] (
     uid: String,
     vocabSize: Int,
     @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
-    sqlContext: SQLContext)
-  extends LDAModel(uid, vocabSize, sqlContext) {
+    sparkSession: SparkSession)
+  extends LDAModel(uid, vocabSize, sparkSession) {
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): LocalLDAModel = {
-    val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+    val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
     copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
   }
 
@@ -554,7 +554,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
       val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
         oldModel.topicConcentration, oldModel.gammaShape)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
@@ -565,7 +565,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
     override def load(path: String): LocalLDAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.parquet(dataPath)
+      val data = sparkSession.read.parquet(dataPath)
         .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
           "gammaShape")
         .head()
@@ -576,7 +576,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
       val gammaShape = data.getAs[Double](4)
       val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
         gammaShape)
-      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -604,13 +604,13 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
  */
 @Since("1.6.0")
 @Experimental
-class DistributedLDAModel private[ml] (
+class DistributedLDAModel protected[ml] (
    uid: String,
    vocabSize: Int,
    private val oldDistributedModel: OldDistributedLDAModel,
-   sqlContext: SQLContext,
+   sparkSession: SparkSession,
    private var oldLocalModelOption: Option[OldLocalLDAModel])
-  extends LDAModel(uid, vocabSize, sqlContext) {
+  extends LDAModel(uid, vocabSize, sparkSession) {
 
   override protected def oldLocalModel: OldLocalLDAModel = {
     if (oldLocalModelOption.isEmpty) {
@@ -628,12 +628,12 @@ class DistributedLDAModel private[ml] (
    * WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
    */
   @Since("1.6.0")
-  def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+  def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): DistributedLDAModel = {
-    val copied =
-      new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
+    val copied = new DistributedLDAModel(
+      uid, vocabSize, oldDistributedModel, sparkSession, oldLocalModelOption)
     copyValues(copied, extra).setParent(parent)
     copied
   }
@@ -688,7 +688,7 @@ class DistributedLDAModel private[ml] (
   @DeveloperApi
   @Since("2.0.0")
   def deleteCheckpointFiles(): Unit = {
-    val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+    val fs = FileSystem.get(sparkSession.sparkContext.hadoopConfiguration)
     _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
     _checkpointFiles = Array.empty[String]
   }
@@ -720,7 +720,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
       val modelPath = new Path(path, "oldModel").toString
       val oldModel = OldDistributedLDAModel.load(sc, modelPath)
       val model = new DistributedLDAModel(
-        metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+        metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -856,9 +856,9 @@ class LDA @Since("1.6.0") (
     val oldModel = oldLDA.run(oldData)
     val newModel = oldModel match {
       case m: OldLocalLDAModel =>
-        new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+        new LocalLDAModel(uid, m.vocabSize, m, dataset.sparkSession)
       case m: OldDistributedLDAModel =>
-        new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
+        new DistributedLDAModel(uid, m.vocabSize, m, dataset.sparkSession, None)
     }
     copyValues(newModel).setParent(this)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 922670a41b..3fbfce9d48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -230,7 +230,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
     transformSchema(dataset.schema, logging = true)
     if (broadcastDict.isEmpty) {
       val dict = vocabulary.zipWithIndex.toMap
-      broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
+      broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
     }
     val dictBr = broadcastDict.get
     val minTf = $(minTF)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 2002d15745..400435d7a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,7 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val tableName = Identifiable.randomUID(uid)
     dataset.registerTempTable(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
-    dataset.sqlContext.sql(realStatement)
+    dataset.sparkSession.sql(realStatement)
   }
 
   @Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index a72692960f..c49e263df0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -226,7 +226,7 @@ class Word2VecModel private[ml] (
     val vectors = wordVectors.getVectors
       .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
       .map(identity) // mapValues doesn't return a serializable map (SI-7005)
-    val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+    val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
     val d = $(vectorSize)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 728b4d1598..6e4e6a6d85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -387,7 +387,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): ALSModel = {
-    import dataset.sqlContext.implicits._
+    import dataset.sparkSession.implicits._
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
       .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 741724d7a1..da51cb7800 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -186,7 +186,7 @@ final class GBTRegressionModel private[ml](
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 4c4ff278d4..8eaed8b682 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -165,7 +165,7 @@ final class RandomForestRegressionModel private[ml] (
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index dc2a6f5275..f07374ac0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
@@ -85,12 +85,12 @@ private[libsvm] class LibSVMOutputWriter(
  * optionally specify options, for example:
  * {{{
  *   // Scala
- *   val df = sqlContext.read.format("libsvm")
+ *   val df = spark.read.format("libsvm")
  *     .option("numFeatures", "780")
 *     .load("data/mllib/sample_libsvm_data.txt")
 *
 *   // Java
- *   DataFrame df = sqlContext.read().format("libsvm")
+ *   DataFrame df = spark.read().format("libsvm")
 *     .option("numFeatures, "780")
 *     .load("data/mllib/sample_libsvm_data.txt");
 * }}}
@@ -123,7 +123,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   }
 
   override def inferSchema(
-      sqlContext: SQLContext,
+      sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
     Some(
@@ -133,7 +133,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   }
 
   override def prepareRead(
-      sqlContext: SQLContext,
+      sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Map[String, String] = {
     def computeNumFeatures(): Int = {
@@ -146,7 +146,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         throw new IOException("Multiple input paths are not supported for libsvm data.")
       }
 
-      val sc = sqlContext.sparkContext
+      val sc = sparkSession.sparkContext
      val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
      MLUtils.computeNumFeatures(parsed)
    }
@@ -159,7 +159,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   }
 
   override def prepareWrite(
-      sqlContext: SQLContext,
+      sparkSession: SparkSession,
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
@@ -176,7 +176,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   }
 
   override def buildReader(
-      sqlContext: SQLContext,
+      sparkSession: SparkSession,
       dataSchema: StructType,
       partitionSchema: StructType,
       requiredSchema: StructType,
@@ -188,9 +188,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
 
     val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
 
-    val broadcastedConf = sqlContext.sparkContext.broadcast(
-      new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
-    )
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(
+        new Configuration(sparkSession.sparkContext.hadoopConfiguration)))
 
     (file: PartitionedFile) => {
       val points =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index de563d4fad..a41d02cde7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -94,7 +94,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
-    val sqlCtx = dataset.sqlContext
+    val sparkSession = dataset.sparkSession
     val est = $(estimator)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
@@ -102,8 +102,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val metrics = new Array[Double](epm.length)
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
-      val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
-      val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+      val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
+      val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
       // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
       val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 12d6905510..f2b7badbe5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -94,7 +94,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
-    val sqlCtx = dataset.sqlContext
     val est = $(estimator)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 7dec07ea14..94d1b83ec2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils
 
 /**
@@ -61,8 +61,10 @@ private[util] sealed trait BaseReadWrite {
     optionSQLContext.get
   }
 
-  /** Returns the [[SparkContext]] underlying [[sqlContext]] */
-  protected final def sc: SparkContext = sqlContext.sparkContext
+  protected final def sparkSession: SparkSession = sqlContext.sparkSession
+
+  /** Returns the underlying [[SparkContext]]. */
+  protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
 /**
--
cgit v1.2.3
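
For readers skimming the diff, here is a minimal, hypothetical Scala sketch (not part of the patch) of the substitution this change applies throughout MLlib internals: code that used to reach the `SparkContext` and SQL facilities through `Dataset.sqlContext` now goes through `Dataset.sparkSession`. The `SessionMigrationSketch` object and its method names are invented for illustration and assume Spark 2.0+ on the classpath.

```scala
import scala.reflect.ClassTag

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, Dataset}

// Hypothetical helpers, for illustration only: the same logic written against
// the old SQLContext-based accessor and the new SparkSession-based one.
object SessionMigrationSketch {

  // Before (Spark 1.x style): reach the SparkContext through the SQLContext.
  def broadcastViaSqlContext[T: ClassTag](dataset: Dataset[_], value: T): Broadcast[T] =
    dataset.sqlContext.sparkContext.broadcast(value)

  // After (Spark 2.0 style): reach it through the SparkSession instead.
  def broadcastViaSparkSession[T: ClassTag](dataset: Dataset[_], value: T): Broadcast[T] =
    dataset.sparkSession.sparkContext.broadcast(value)

  // SQL execution follows the same pattern:
  // dataset.sparkSession.sql(...) replaces dataset.sqlContext.sql(...).
  def runSql(dataset: Dataset[_], statement: String): DataFrame =
    dataset.sparkSession.sql(statement)
}
```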