author     gatorsmile <gatorsmile@gmail.com>          2016-06-21 23:12:08 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-06-21 23:12:08 -0700
commit     0e3ce75332dd536c0db8467d456ad46e4bf228f4 (patch)
tree       2f9d0397ec0250072cdc8713f19c4dc432ccc790
parent     7580f3041a1a3757a0b14b9d8afeb720f261fff6 (diff)
[SPARK-15644][MLLIB][SQL] Replace SQLContext with SparkSession in MLlib
#### What changes were proposed in this pull request?
This PR replaces the existing `SQLContext` with the latest `SparkSession` throughout `MLlib`; `SQLContext` is removed from `MLlib`. It also fixes a test case issue in `BroadcastJoinSuite`. Note that `SQLContext` is not used in the `MLlib` test suites.

#### How was this patch tested?
Existing test cases.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #13380 from gatorsmile/sqlContextML.
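The sketch below is not part of the patch; it is a minimal illustration of the user-facing side of this migration. ML writers and readers gain a `session()` method that takes a `SparkSession`, while the `SQLContext`-based `context()` method is deprecated and delegates to it. `SessionReadWriteSketch`, `fitted`, and the save path are placeholders.

```scala
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.sql.SparkSession

object SessionReadWriteSketch {
  // Round-trips a fitted model through disk using the SparkSession-based API.
  def roundTrip(spark: SparkSession,
                fitted: LogisticRegressionModel,
                path: String): LogisticRegressionModel = {
    // New in this patch: hand the writer a SparkSession via session().
    fitted.write.session(spark).overwrite().save(path)

    // The old SQLContext-based call still compiles, but is deprecated and
    // now forwards to session(sqlContext.sparkSession) internally:
    // fitted.write.context(spark.sqlContext).overwrite().save(path)

    // Readers gain the same session() method.
    LogisticRegressionModel.read.session(spark).load(path)
  }
}
```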
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala  10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala  12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala  12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala  41
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala  2
31 files changed, 100 insertions, 81 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 881dcefb79..c65d3d5b54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -243,7 +243,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
+ sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
}
}
@@ -258,7 +258,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
- val root = loadTreeNodes(path, metadata, sqlContext)
+ val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index f843df449c..4e534baddc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -270,7 +270,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numTrees" -> instance.getNumTrees)
- EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
}
}
@@ -283,7 +283,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
override def load(path: String): GBTClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
- EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 9469acf62e..a7ba39e432 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -660,7 +660,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
instance.coefficients)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -674,7 +674,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.format("parquet").load(dataPath)
+ val data = sparkSession.read.format("parquet").load(dataPath)
.select("numClasses", "numFeatures", "intercept", "coefficients").head()
// We will need numClasses, numFeatures in the future for multinomial logreg support.
// val numClasses = data.getInt(0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index c4e882240f..700542117e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -356,7 +356,7 @@ object MultilayerPerceptronClassificationModel
// Save model data: layers, weights
val data = Data(instance.layers, instance.weights)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -370,7 +370,7 @@ object MultilayerPerceptronClassificationModel
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
+ val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head()
val layers = data.getAs[Seq[Int]](0).toArray
val weights = data.getAs[Vector](1)
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index a98bdeca6b..a9d493032b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -262,7 +262,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
// Save model data: pi, theta
val data = Data(instance.pi, instance.theta)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -275,7 +275,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+ val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
val pi = data.getAs[Vector](0)
val theta = data.getAs[Matrix](1)
val model = new NaiveBayesModel(metadata.uid, pi, theta)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index b3c074f839..9a26a5c5b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -282,7 +282,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses,
"numTrees" -> instance.getNumTrees)
- EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
}
}
@@ -296,7 +296,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
- EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 563a3b14e9..81749055c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -195,7 +195,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
val data = Data(weights, mus, sigmas)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -208,7 +208,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
+ val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
val weights = row.getSeq[Double](0).toArray
val mus = row.getSeq[OldVector](1).toArray
val sigmas = row.getSeq[OldMatrix](2).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 790ef1fe8d..6f63d04818 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -211,7 +211,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
Data(idx, center)
}
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
}
}
@@ -222,8 +222,8 @@ object KMeansModel extends MLReadable[KMeansModel] {
override def load(path: String): KMeansModel = {
// Import implicits for Dataset Encoder
- val sqlContext = super.sqlContext
- import sqlContext.implicits._
+ val sparkSession = super.sparkSession
+ import sparkSession.implicits._
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
@@ -232,11 +232,11 @@ object KMeansModel extends MLReadable[KMeansModel] {
val versionRegex(major, _) = metadata.sparkVersion
val clusterCenters = if (major.toInt >= 2) {
- val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+ val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
- sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+ sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
}
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 33723287b0..38b4db99bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -202,7 +202,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.selectedFeatures.toSeq)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -213,7 +213,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
override def load(path: String): ChiSqSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
+ val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
val model = new ChiSqSelectorModel(metadata.uid, oldModel)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 3250fe5598..96e6f1c512 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -297,7 +297,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.vocabulary)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -308,7 +308,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
override def load(path: String): CountVectorizerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("vocabulary")
.head()
val vocabulary = data.getAs[Seq[String]](0).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index cf03a2845c..02d4e6a9f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -168,7 +168,7 @@ object IDFModel extends MLReadable[IDFModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.idf)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -179,7 +179,7 @@ object IDFModel extends MLReadable[IDFModel] {
override def load(path: String): IDFModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("idf")
.head()
val idf = data.getAs[Vector](0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 31a5815267..acabf0b892 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -161,7 +161,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = new Data(instance.maxAbs)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
override def load(path: String): MaxAbsScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath)
+ val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
.select("maxAbs")
.head()
val model = new MaxAbsScalerModel(metadata.uid, maxAbs)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index dd5a1f9b41..562b3f38e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -221,7 +221,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = new Data(instance.originalMin, instance.originalMax)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -232,7 +232,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
override def load(path: String): MinMaxScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
+ val Row(originalMin: Vector, originalMax: Vector) = sparkSession.read.parquet(dataPath)
.select("originalMin", "originalMax")
.head()
val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index b89c85991f..72167b50e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -186,7 +186,7 @@ object PCAModel extends MLReadable[PCAModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.pc, instance.explainedVariance)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -217,12 +217,12 @@ object PCAModel extends MLReadable[PCAModel] {
val dataPath = new Path(path, "data").toString
val model = if (hasExplainedVariance) {
val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
- sqlContext.read.parquet(dataPath)
+ sparkSession.read.parquet(dataPath)
.select("pc", "explainedVariance")
.head()
new PCAModel(metadata.uid, pc, explainedVariance)
} else {
- val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head()
+ val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
}
DefaultParamsReader.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 546dc7e8c0..c95dacfce8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -297,7 +297,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: resolvedFormula
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+ sparkSession.createDataFrame(Seq(instance.resolvedFormula))
.repartition(1).write.parquet(dataPath)
// Save pipeline model
val pmPath = new Path(path, "pipelineModel").toString
@@ -314,7 +314,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+ val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
val label = data.getString(0)
val terms = data.getAs[Seq[Seq[String]]](1)
val hasIntercept = data.getBoolean(2)
@@ -372,7 +372,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
// Save model data: columnsToPrune
val data = Data(instance.columnsToPrune.toSeq)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -385,7 +385,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+ val data = sparkSession.read.parquet(dataPath).select("columnsToPrune").head()
val columnsToPrune = data.getAs[Seq[String]](0).toSet
val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
@@ -463,7 +463,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
// Save model data: vectorCol, prefixesToRewrite
val data = Data(instance.vectorCol, instance.prefixesToRewrite)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -476,7 +476,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+ val data = sparkSession.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
val vectorCol = data.getString(0)
val prefixesToRewrite = data.getAs[Map[String, String]](1)
val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 5e1bacf876..be58dc27e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -200,7 +200,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.std, instance.mean)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -211,7 +211,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+ val Row(std: Vector, mean: Vector) = sparkSession.read.parquet(dataPath)
.select("std", "mean")
.head()
val model = new StandardScalerModel(metadata.uid, std, mean)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 0f7337ce6b..028e540fe5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -221,7 +221,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.labels)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -232,7 +232,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
override def load(path: String): StringIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("labels")
.head()
val labels = data.getAs[Seq[String]](0).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 52db996c84..5656a9f979 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -450,7 +450,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.numFeatures, instance.categoryMaps)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -461,7 +461,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
override def load(path: String): VectorIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("numFeatures", "categoryMaps")
.head()
val numFeatures = data.getAs[Int](0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 05c4f2f1a7..a74d31ff9d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -310,7 +310,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -321,7 +321,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
override def load(path: String): Word2VecModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("wordIndex", "wordVectors")
.head()
val wordIndex = data.getAs[Map[String, Int]](0)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2404a69e9e..5dc2433e55 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -320,9 +320,9 @@ object ALSModel extends MLReadable[ALSModel] {
implicit val format = DefaultFormats
val rank = (metadata.metadata \ "rank").extract[Int]
val userPath = new Path(path, "userFactors").toString
- val userFactors = sqlContext.read.format("parquet").load(userPath)
+ val userFactors = sparkSession.read.format("parquet").load(userPath)
val itemPath = new Path(path, "itemFactors").toString
- val itemFactors = sqlContext.read.format("parquet").load(itemPath)
+ val itemFactors = sparkSession.read.format("parquet").load(itemPath)
val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 7f57af19e9..fe65e3e810 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -375,7 +375,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
// Save model data: coefficients, intercept, scale
val data = Data(instance.coefficients, instance.intercept, instance.scale)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -388,7 +388,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("coefficients", "intercept", "scale").head()
val coefficients = data.getAs[Vector](0)
val intercept = data.getDouble(1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index c4df9d1112..7ff6d0afd5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -249,7 +249,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(nodeData).write.parquet(dataPath)
+ sparkSession.createDataFrame(nodeData).write.parquet(dataPath)
}
}
@@ -263,7 +263,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
- val root = loadTreeNodes(path, metadata, sqlContext)
+ val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
DefaultParamsReader.getAndSetParams(model, metadata)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 81f2139f0b..6223555504 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -252,7 +252,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numTrees" -> instance.getNumTrees)
- EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
}
}
@@ -265,7 +265,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
override def load(path: String): GBTRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
- EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index adbdd345e9..a23e90d9e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -813,7 +813,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -827,7 +827,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("intercept", "coefficients").head()
val intercept = data.getDouble(0)
val coefficients = data.getAs[Vector](1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index d16e8e3f6b..f05b47eda7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -284,7 +284,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {
val data = Data(
instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -297,7 +297,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath)
+ val data = sparkSession.read.parquet(dataPath)
.select("boundaries", "predictions", "isotonic").head()
val boundaries = data.getAs[Seq[Double]](0).toArray
val predictions = data.getAs[Seq[Double]](1).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 52ec40e15b..5e8ef1b375 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -486,7 +486,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
@@ -499,7 +499,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.format("parquet").load(dataPath)
+ val data = sparkSession.read.format("parquet").load(dataPath)
.select("intercept", "coefficients").head()
val intercept = data.getDouble(0)
val coefficients = data.getAs[Vector](1)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a6dbf21d55..4f4d3d2784 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -244,7 +244,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numTrees" -> instance.getNumTrees)
- EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
}
}
@@ -257,7 +257,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
- EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 56c85c9b53..5b6fcc53c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, SQLContext}
+import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -332,8 +332,8 @@ private[ml] object DecisionTreeModelReadWrite {
def loadTreeNodes(
path: String,
metadata: DefaultParamsReader.Metadata,
- sqlContext: SQLContext): Node = {
- import sqlContext.implicits._
+ sparkSession: SparkSession): Node = {
+ import sparkSession.implicits._
implicit val format = DefaultFormats
// Get impurity to construct ImpurityCalculator for each node
@@ -343,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite {
}
val dataPath = new Path(path, "data").toString
- val data = sqlContext.read.parquet(dataPath).as[NodeData]
+ val data = sparkSession.read.parquet(dataPath).as[NodeData]
buildTreeFromNodes(data.collect(), impurityType)
}
@@ -393,7 +393,7 @@ private[ml] object EnsembleModelReadWrite {
def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
instance: M,
path: String,
- sql: SQLContext,
+ sql: SparkSession,
extraMetadata: JObject): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
@@ -424,7 +424,7 @@ private[ml] object EnsembleModelReadWrite {
*/
def loadImpl(
path: String,
- sql: SQLContext,
+ sql: SparkSession,
className: String,
treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
import sql.implicits._
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 90b8d7df7b..1582a73ea0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -40,28 +40,41 @@ import org.apache.spark.util.Utils
* Trait for [[MLWriter]] and [[MLReader]].
*/
private[util] sealed trait BaseReadWrite {
- private var optionSQLContext: Option[SQLContext] = None
+ private var optionSparkSession: Option[SparkSession] = None
/**
- * Sets the SQL context to use for saving/loading.
+ * Sets the Spark SQLContext to use for saving/loading.
*/
@Since("1.6.0")
+ @deprecated("Use session instead", "2.0.0")
def context(sqlContext: SQLContext): this.type = {
- optionSQLContext = Option(sqlContext)
+ optionSparkSession = Option(sqlContext.sparkSession)
this
}
/**
- * Returns the user-specified SQL context or the default.
+ * Sets the Spark Session to use for saving/loading.
*/
- protected final def sqlContext: SQLContext = {
- if (optionSQLContext.isEmpty) {
- optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate()))
+ @Since("2.0.0")
+ def session(sparkSession: SparkSession): this.type = {
+ optionSparkSession = Option(sparkSession)
+ this
+ }
+
+ /**
+ * Returns the user-specified Spark Session or the default.
+ */
+ protected final def sparkSession: SparkSession = {
+ if (optionSparkSession.isEmpty) {
+ optionSparkSession = Some(SparkSession.builder().getOrCreate())
}
- optionSQLContext.get
+ optionSparkSession.get
}
- protected final def sparkSession: SparkSession = sqlContext.sparkSession
+ /**
+ * Returns the user-specified SQL context or the default.
+ */
+ protected final def sqlContext: SQLContext = sparkSession.sqlContext
/** Returns the underlying [[SparkContext]]. */
protected final def sc: SparkContext = sparkSession.sparkContext
@@ -118,7 +131,10 @@ abstract class MLWriter extends BaseReadWrite with Logging {
}
// override for Java compatibility
- override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+ override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+ // override for Java compatibility
+ override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}
/**
@@ -180,7 +196,10 @@ abstract class MLReader[T] extends BaseReadWrite {
def load(path: String): T
// override for Java compatibility
- override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+ override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+ // override for Java compatibility
+ override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 7bda219243..e4f678fef1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -56,7 +56,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession {
} catch (IOException e) {
// expected
}
- instance.write().context(spark.sqlContext()).overwrite().save(outputPath);
+ instance.write().session(spark).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 251f47d5fb..a3fd39d42e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -110,7 +110,7 @@ class SparkSession private(
* A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility.
*/
@transient
- private[sql] val sqlContext: SQLContext = new SQLContext(this)
+ private[spark] val sqlContext: SQLContext = new SQLContext(this)
/**
* Runtime configuration interface for Spark.