author     Andrew Or <andrew@databricks.com>      2016-04-25 20:54:31 -0700
committer  Reynold Xin <rxin@databricks.com>      2016-04-25 20:54:31 -0700
commit     18c2c92580bdc27aa5129d9e7abda418a3633ea6 (patch)
tree       5d4222dd42ea584ad8259fbe7e22f1eb9b5bdb3c /mllib
parent     fa3c06987e6148975dd54b629bd9094224358175 (diff)
[SPARK-14861][SQL] Replace internal usages of SQLContext with SparkSession
## What changes were proposed in this pull request?

In Spark 2.0, `SparkSession` is the new thing. Internally we should stop using `SQLContext` everywhere, since it is no longer supposed to be the main user-facing API. In this patch I took care not to break any public APIs. The one place that's suspect is `o.a.s.ml.source.libsvm.DefaultSource`, but according to mengxr it's not supposed to be public, so it's OK to change the underlying `FileFormat` trait.

**Reviewers**: This is a big patch that may be difficult to review, but the changes are actually really straightforward. If you prefer, I can break it up into a few smaller patches, but that would delay the progress of this issue a little.

## How was this patch tested?

No change in functionality intended.

Author: Andrew Or <andrew@databricks.com>

Closes #12625 from andrewor14/spark-session-refactor.
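For orientation (not part of the original commit message): a minimal sketch of the Spark 2.0 entry-point change this refactor follows. User code obtains a `SparkSession` via `SparkSession.builder` instead of constructing a `SQLContext`, and internal code reaches it through `dataset.sparkSession`. The app name, master setting, and data path below are illustrative assumptions, not values taken from this patch.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkSessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    // Spark 2.0 entry point: SparkSession replaces SQLContext as the main user-facing API.
    val spark: SparkSession = SparkSession.builder()
      .appName("spark-session-migration-sketch") // illustrative name
      .master("local[*]")                        // assumption: local mode, for the sketch only
      .getOrCreate()

    // Reading LibSVM data through the session, mirroring the Scaladoc example updated below.
    val df: DataFrame = spark.read.format("libsvm")
      .option("numFeatures", "780")
      .load("data/mllib/sample_libsvm_data.txt")

    // Internal ML code in this patch follows the same pattern, e.g.
    //   dataset.sparkSession.sparkContext.broadcast(...)
    // in place of dataset.sqlContext.sparkContext.broadcast(...).
    df.show(5)
    spark.stop()
  }
}
```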
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala           |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala      |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala                         | 46
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala                |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala                 |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala                       |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                     |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala       |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala           | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala                  |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala            |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala                         |  8
14 files changed, 52 insertions, 51 deletions
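The hunks below largely apply one mechanical substitution: wherever internal ML code previously reached Spark through `dataset.sqlContext`, it now goes through `dataset.sparkSession`. A minimal sketch of the recurring broadcast-and-predict pattern in its post-patch form, using a hypothetical `MyModel` class that is not part of this patch:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical model, shown only to illustrate the substitution pattern in the diff below.
class MyModel(private val weights: Vector) extends Serializable {
  // Placeholder scoring logic; a real model would use `weights` here.
  def predict(features: Vector): Double = 0.0

  def transformImpl(dataset: Dataset[_]): DataFrame = {
    // Post-patch form: obtain the SparkContext via dataset.sparkSession
    // rather than dataset.sqlContext, then broadcast the model to executors.
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    dataset.withColumn("prediction", predictUDF(col("features")))
  }
}
```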
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 39a698af15..1bd6dae7dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -199,7 +199,7 @@ final class GBTClassificationModel private[ml](
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c2b440059b..ec7fc977b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -840,8 +840,8 @@ class BinaryLogisticRegressionSummary private[classification] (
@Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
- private val sqlContext = predictions.sqlContext
- import sqlContext.implicits._
+ private val sparkSession = predictions.sparkSession
+ import sparkSession.implicits._
/**
* Returns a BinaryClassificationMetrics object.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index dfa711b243..c04ecc88ae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -181,7 +181,7 @@ final class RandomForestClassificationModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index c57ceba4a9..27c813dd61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
import org.apache.spark.sql.types.StructType
@@ -356,14 +356,14 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* Model fitted by [[LDA]].
*
* @param vocabSize Vocabulary size (number of terms or terms in the vocabulary)
- * @param sqlContext Used to construct local DataFrames for returning query results
+ * @param sparkSession Used to construct local DataFrames for returning query results
*/
@Since("1.6.0")
@Experimental
-sealed abstract class LDAModel private[ml] (
+sealed abstract class LDAModel protected[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
- @Since("1.6.0") @transient protected val sqlContext: SQLContext)
+ @Since("1.6.0") @transient protected[ml] val sparkSession: SparkSession)
extends Model[LDAModel] with LDAParams with Logging with MLWritable {
// NOTE to developers:
@@ -405,7 +405,7 @@ sealed abstract class LDAModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
- val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
+ val t = udf(oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext))
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
@@ -495,7 +495,7 @@ sealed abstract class LDAModel private[ml] (
case ((termIndices, termWeights), topic) =>
(topic, termIndices.toSeq, termWeights.toSeq)
}
- sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
+ sparkSession.createDataFrame(topics).toDF("topic", "termIndices", "termWeights")
}
@Since("1.6.0")
@@ -512,16 +512,16 @@ sealed abstract class LDAModel private[ml] (
*/
@Since("1.6.0")
@Experimental
-class LocalLDAModel private[ml] (
+class LocalLDAModel protected[ml] (
uid: String,
vocabSize: Int,
@Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
- sqlContext: SQLContext)
- extends LDAModel(uid, vocabSize, sqlContext) {
+ sparkSession: SparkSession)
+ extends LDAModel(uid, vocabSize, sparkSession) {
@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
- val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
}
@@ -554,7 +554,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
oldModel.topicConcentration, oldModel.gammaShape)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -565,7 +565,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
override def load(path: String): LocalLDAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
"gammaShape")
.head()
@@ -576,7 +576,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val gammaShape = data.getAs[Double](4)
val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
gammaShape)
- val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+ val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
@@ -604,13 +604,13 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
*/
@Since("1.6.0")
@Experimental
-class DistributedLDAModel private[ml] (
+class DistributedLDAModel protected[ml] (
uid: String,
vocabSize: Int,
private val oldDistributedModel: OldDistributedLDAModel,
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
private var oldLocalModelOption: Option[OldLocalLDAModel])
- extends LDAModel(uid, vocabSize, sqlContext) {
+ extends LDAModel(uid, vocabSize, sparkSession) {
override protected def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
@@ -628,12 +628,12 @@ class DistributedLDAModel private[ml] (
* WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
*/
@Since("1.6.0")
- def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
@Since("1.6.0")
override def copy(extra: ParamMap): DistributedLDAModel = {
- val copied =
- new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
+ val copied = new DistributedLDAModel(
+ uid, vocabSize, oldDistributedModel, sparkSession, oldLocalModelOption)
copyValues(copied, extra).setParent(parent)
copied
}
@@ -688,7 +688,7 @@ class DistributedLDAModel private[ml] (
@DeveloperApi
@Since("2.0.0")
def deleteCheckpointFiles(): Unit = {
- val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ val fs = FileSystem.get(sparkSession.sparkContext.hadoopConfiguration)
_checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
_checkpointFiles = Array.empty[String]
}
@@ -720,7 +720,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(
- metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+ metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
@@ -856,9 +856,9 @@ class LDA @Since("1.6.0") (
val oldModel = oldLDA.run(oldData)
val newModel = oldModel match {
case m: OldLocalLDAModel =>
- new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+ new LocalLDAModel(uid, m.vocabSize, m, dataset.sparkSession)
case m: OldDistributedLDAModel =>
- new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
+ new DistributedLDAModel(uid, m.vocabSize, m, dataset.sparkSession, None)
}
copyValues(newModel).setParent(this)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 922670a41b..3fbfce9d48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -230,7 +230,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap
- broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
+ broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
}
val dictBr = broadcastDict.get
val minTf = $(minTF)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 2002d15745..400435d7a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,7 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- dataset.sqlContext.sql(realStatement)
+ dataset.sparkSession.sql(realStatement)
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index a72692960f..c49e263df0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -226,7 +226,7 @@ class Word2VecModel private[ml] (
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
.map(identity) // mapValues doesn't return a serializable map (SI-7005)
- val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+ val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
val d = $(vectorSize)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 728b4d1598..6e4e6a6d85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -387,7 +387,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ALSModel = {
- import dataset.sqlContext.implicits._
+ import dataset.sparkSession.implicits._
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 741724d7a1..da51cb7800 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -186,7 +186,7 @@ final class GBTRegressionModel private[ml](
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 4c4ff278d4..8eaed8b682 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -165,7 +165,7 @@ final class RandomForestRegressionModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index dc2a6f5275..f07374ac0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
@@ -85,12 +85,12 @@ private[libsvm] class LibSVMOutputWriter(
* optionally specify options, for example:
* {{{
* // Scala
- * val df = sqlContext.read.format("libsvm")
+ * val df = spark.read.format("libsvm")
* .option("numFeatures", "780")
* .load("data/mllib/sample_libsvm_data.txt")
*
* // Java
- * DataFrame df = sqlContext.read().format("libsvm")
+ * DataFrame df = spark.read().format("libsvm")
* .option("numFeatures, "780")
* .load("data/mllib/sample_libsvm_data.txt");
* }}}
@@ -123,7 +123,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def inferSchema(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
Some(
@@ -133,7 +133,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def prepareRead(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Map[String, String] = {
def computeNumFeatures(): Int = {
@@ -146,7 +146,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
throw new IOException("Multiple input paths are not supported for libsvm data.")
}
- val sc = sqlContext.sparkContext
+ val sc = sparkSession.sparkContext
val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
MLUtils.computeNumFeatures(parsed)
}
@@ -159,7 +159,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def prepareWrite(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
@@ -176,7 +176,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def buildReader(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
@@ -188,9 +188,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
- val broadcastedConf = sqlContext.sparkContext.broadcast(
- new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
- )
+ val broadcastedConf = sparkSession.sparkContext.broadcast(
+ new SerializableConfiguration(
+ new Configuration(sparkSession.sparkContext.hadoopConfiguration)))
(file: PartitionedFile) => {
val points =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index de563d4fad..a41d02cde7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -94,7 +94,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
- val sqlCtx = dataset.sqlContext
+ val sparkSession = dataset.sparkSession
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
@@ -102,8 +102,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val metrics = new Array[Double](epm.length)
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
- val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
- val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+ val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
+ val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 12d6905510..f2b7badbe5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -94,7 +94,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
- val sqlCtx = dataset.sqlContext
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 7dec07ea14..94d1b83ec2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils
/**
@@ -61,8 +61,10 @@ private[util] sealed trait BaseReadWrite {
optionSQLContext.get
}
- /** Returns the [[SparkContext]] underlying [[sqlContext]] */
- protected final def sc: SparkContext = sqlContext.sparkContext
+ protected final def sparkSession: SparkSession = sqlContext.sparkSession
+
+ /** Returns the underlying [[SparkContext]]. */
+ protected final def sc: SparkContext = sparkSession.sparkContext
}
/**