From 85d6b0db9f5bd425c36482ffcb1c3b9fd0fcdb31 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 31 May 2016 17:40:44 -0700 Subject: [SPARK-15618][SQL][MLLIB] Use SparkSession.builder.sparkContext if applicable. ## What changes were proposed in this pull request? This PR changes the visibility of the function `SparkSession.builder.sparkContext(..)` from **private[sql]** to **private[spark]**, and uses it where applicable, as in the following example. ``` - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() ``` ## How was this patch tested? Pass the existing Jenkins tests. Author: Dongjoon Hyun Closes #13365 from dongjoon-hyun/SPARK-15618. --- .../scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- .../apache/spark/mllib/classification/LogisticRegression.scala | 2 +- .../scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++---- .../spark/mllib/classification/impl/GLMClassificationModel.scala | 4 ++-- .../org/apache/spark/mllib/clustering/BisectingKMeansModel.scala | 4 ++-- .../org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 4 ++-- .../scala/org/apache/spark/mllib/clustering/KMeansModel.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/clustering/LDAModel.scala | 8 ++++---- .../apache/spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 4 ++-- .../spark/mllib/recommendation/MatrixFactorizationModel.scala | 4 ++-- .../org/apache/spark/mllib/regression/IsotonicRegression.scala | 4 ++-- .../apache/spark/mllib/regression/impl/GLMRegressionModel.scala | 4 ++-- .../org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4 ++-- 
.../org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 8 ++------ .../org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 6 +++--- .../test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 3 ++- .../src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala | 3 ++- 22 files changed, 49 insertions(+), 51 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 6e0ed374c7..e43469bf1c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1177,7 +1177,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. val sc = indexedRowMatrix.rows.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(indexedRowMatrix.rows) } @@ -1188,7 +1188,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. val sc = coordinateMatrix.entries.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(coordinateMatrix.entries) } @@ -1199,7 +1199,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. 
val sc = blockMatrix.blocks.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(blockMatrix.blocks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index b186ca3770..e4cc784cfe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -437,7 +437,7 @@ class LogisticRegressionWithLBFGS lr.setMaxIter(optimizer.getNumIterations()) lr.setTol(optimizer.getConvergenceTol()) // Convert our input into a DataFrame - val spark = SparkSession.builder().config(input.context.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(input.context).getOrCreate() val df = spark.createDataFrame(input.map(_.asML)) // Determine if we should cache the DF val handlePersistence = input.getStorageLevel == StorageLevel.NONE diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 452802f043..593a86f69a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -193,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. 
val metadata = compact(render( @@ -207,7 +207,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Load Parquet data. val dataRDD = spark.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -238,7 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -251,7 +251,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Load Parquet data. val dataRDD = spark.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 32e323d080..84491181d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -73,7 +73,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataRDD = spark.read.parquet(dataPath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 510a91b5a7..b3546a1ee3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -144,7 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { - val spark = 
SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rootId" -> model.root.index))) @@ -165,7 +165,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { } def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 31ad56dba6..c30cc3e239 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -143,7 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { path: String, weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render @@ -159,7 +159,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 5f939c1a21..aa78149699 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -123,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) @@ -135,7 +135,7 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 1b66013d54..4f07236225 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -446,7 +446,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { docConcentration: Vector, topicConcentration: Double, gammaShape: Double): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = 
SparkSession.builder().sparkContext(sc).getOrCreate() val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -470,7 +470,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { topicConcentration: Double, gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(dataPath) Loader.checkSchema[Data](dataFrame.schema) @@ -851,7 +851,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { topicConcentration: Double, iterationTimes: Array[Double], gammaShape: Double): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -887,7 +887,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(dataPath) val vertexDataFrame = spark.read.parquet(vertexDataPath) val edgeDataFrame = spark.read.parquet(edgeDataPath) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 51077bd630..c760ddd6ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -82,7 +82,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 13decefcd6..c8c2823bba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -149,7 +149,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val 
formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9bd79aa7c6..2f52825c6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -609,7 +609,7 @@ object Word2VecModel extends Loader[Word2VecModel] { case class Data(word: String, vector: Array[Float]) def load(sc: SparkContext, path: String): Word2VecModel = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataFrame = spark.read.parquet(Loader.dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) @@ -620,7 +620,7 @@ object Word2VecModel extends Loader[Word2VecModel] { } def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val vectorSize = model.values.head.length val numWords = model.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 8c0639baea..0f7fbe9556 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { def save(model: FPGrowthModel[_], path: String): Unit = { val sc = model.freqItemsets.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -123,7 +123,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { def load(sc: SparkContext, path: String): FPGrowthModel[_] = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 10bbcd2a3d..c13c794775 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -616,7 +616,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { def save(model: PrefixSpanModel[_], path: 
String): Unit = { val sc = model.freqSequences.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -640,7 +640,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 450025f477..c642573ccb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -354,7 +354,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() import spark.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -365,7 +365,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, 
formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 215a799b96..1cd6f2a896 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataRDD = spark.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 3c7bbc5244..cd90e97cc5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create 
JSON metadata. val metadata = compact(render( @@ -68,7 +68,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val dataPath = Loader.dataPath(path) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataRDD = spark.read.parquet(dataPath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 72663188a9..a1562384b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -233,13 +233,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { // Create Parquet data. val nodes = model.topNode.subtreeIterator.toSeq val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _)) - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { // Load Parquet data. - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val dataPath = Loader.dataPath(path) val dataRDD = spark.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index c653b988e2..f7d9b22b6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -413,7 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // SPARK-6120: We do a hacky check here so users understand why save() is failing // when they run the ML guide example. @@ -471,7 +471,7 @@ private[tree] object TreeEnsembleModel extends Logging { sc: SparkContext, path: String, treeAlgo: String): Array[DecisionTreeModel] = { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() import spark.implicits._ val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply) val trees = constructTrees(nodes.rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 40d5b4881f..3558290b23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -23,18 +23,14 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.Row class 
ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test Chi-Square selector") { - val spark = SparkSession.builder - .master("local[2]") - .appName("ChiSqSelectorSuite") - .getOrCreate() + val spark = this.spark import spark.implicits._ - val data = Seq( LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 621c13a8e5..b73dbd6232 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -27,7 +27,7 @@ class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test observed number of buckets and their sizes match expected values") { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = this.spark import spark.implicits._ val datasetSize = 100000 @@ -53,7 +53,7 @@ class QuantileDiscretizerSuite } test("Test transform method on unseen data") { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = this.spark import spark.implicits._ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") @@ -82,7 +82,7 @@ class QuantileDiscretizerSuite } test("Verify resulting model has parent") { - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = this.spark import spark.implicits._ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 59b5edc401..e8ed50acf8 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -591,6 +591,7 @@ class ALSCleanerSuite extends SparkFunSuite { val spark = SparkSession.builder .master("local[2]") .appName("ALSCleanerSuite") + .sparkContext(sc) .getOrCreate() import spark.implicits._ val als = new ALS() @@ -606,7 +607,7 @@ class ALSCleanerSuite extends SparkFunSuite { val pattern = "shuffle_(\\d+)_.+\\.data".r val rddIds = resultingFiles.flatMap { f => pattern.findAllIn(f.getName()).matchData.map { _.group(1) } } - assert(rddIds.toSet.size === 4) + assert(rddIds.size === 4) } finally { sc.stop() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 8cbd652bac..d2fa8d0d63 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -42,9 +42,10 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val spark = SparkSession.builder + val spark = SparkSession.builder() .master("local[2]") .appName("TreeTests") + .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ -- cgit v1.2.3