From 85d6b0db9f5bd425c36482ffcb1c3b9fd0fcdb31 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Tue, 31 May 2016 17:40:44 -0700
Subject: [SPARK-15618][SQL][MLLIB] Use SparkSession.builder.sparkContext if
 applicable.

## What changes were proposed in this pull request?

This PR changes the visibility of `SparkSession.builder.sparkContext(..)` from **private[sql]** to **private[spark]**, and uses it where applicable, as follows.

```
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
```

## How was this patch tested?

Passes the existing Jenkins tests.

Author: Dongjoon Hyun

Closes #13365 from dongjoon-hyun/SPARK-15618.
---
 .../org/apache/spark/examples/BroadcastTest.scala  |  7 ++--
 .../apache/spark/examples/mllib/LDAExample.scala   |  1 +
 .../spark/examples/sql/hive/HiveFromSpark.scala    |  5 +--
 .../spark/mllib/api/python/PythonMLLibAPI.scala    |  6 ++--
 .../mllib/classification/LogisticRegression.scala  |  2 +-
 .../spark/mllib/classification/NaiveBayes.scala    |  8 ++---
 .../impl/GLMClassificationModel.scala              |  4 +--
 .../mllib/clustering/BisectingKMeansModel.scala    |  4 +--
 .../mllib/clustering/GaussianMixtureModel.scala    |  4 +--
 .../spark/mllib/clustering/KMeansModel.scala       |  4 +--
 .../apache/spark/mllib/clustering/LDAModel.scala   |  8 ++---
 .../clustering/PowerIterationClustering.scala      |  4 +--
 .../apache/spark/mllib/feature/ChiSqSelector.scala |  4 +--
 .../org/apache/spark/mllib/feature/Word2Vec.scala  |  4 +--
 .../org/apache/spark/mllib/fpm/FPGrowth.scala      |  4 +--
 .../org/apache/spark/mllib/fpm/PrefixSpan.scala    |  4 +--
 .../recommendation/MatrixFactorizationModel.scala  |  4 +--
 .../mllib/regression/IsotonicRegression.scala      |  4 +--
 .../mllib/regression/impl/GLMRegressionModel.scala |  4 +--
 .../spark/mllib/tree/model/DecisionTreeModel.scala |  4 +--
 .../mllib/tree/model/treeEnsembleModels.scala      |  4 +--
 .../spark/ml/feature/ChiSqSelectorSuite.scala      |  8 ++---
 .../ml/feature/QuantileDiscretizerSuite.scala      |  6 ++--
 .../apache/spark/ml/recommendation/ALSSuite.scala  |  3 +-
 .../org/apache/spark/ml/tree/impl/TreeTests.scala  |  3 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  2 +-
 .../sql/execution/joins/BroadcastJoinSuite.scala   | 41 +++++++++++-----------
 .../spark/sql/hive/HiveSparkSubmitSuite.scala      | 13 +++----
 28 files changed, 79 insertions(+), 90 deletions(-)

diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index c50f25d951..a68fd0285f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -29,13 +29,10 @@ object BroadcastTest {
 
     val blockSize = if (args.length > 2) args(2) else "4096"
 
-    val sparkConf = new SparkConf()
-      .set("spark.broadcast.blockSize", blockSize)
-
     val spark = SparkSession
-      .builder
-      .config(sparkConf)
+      .builder()
       .appName("Broadcast Test")
+      .config("spark.broadcast.blockSize", blockSize)
       .getOrCreate()
 
     val sc = spark.sparkContext
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 7651aade49..3fbf8e0333 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -191,6 +191,7 @@ object LDAExample {
 
     val spark = SparkSession
       .builder
+      .sparkContext(sc)
       .getOrCreate()
     import spark.implicits._
 
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index d3bb7e4398..2d7a01a95d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -22,7 +22,6 @@ import java.io.File
 
 import com.google.common.io.{ByteStreams, Files}
 
-import org.apache.spark.SparkConf
 import org.apache.spark.sql._
 
 object HiveFromSpark {
@@ -35,8 +34,6 @@ object HiveFromSpark {
   ByteStreams.copy(kv1Stream, Files.newOutputStreamSupplier(kv1File))
 
   def main(args: Array[String]) {
-    val sparkConf = new SparkConf().setAppName("HiveFromSpark")
-
     // When working with Hive, one must instantiate `SparkSession` with Hive support, including
     // connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined
     // functions. Users who do not have an existing Hive deployment can still enable Hive support.
@@ -45,7 +42,7 @@ object HiveFromSpark {
     // which defaults to the directory `spark-warehouse` in the current directory that the spark
    // application is started.
     val spark = SparkSession.builder
-      .config(sparkConf)
+      .appName("HiveFromSpark")
       .enableHiveSupport()
       .getOrCreate()
     val sc = spark.sparkContext
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6e0ed374c7..e43469bf1c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1177,7 +1177,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     // We use DataFrames for serialization of IndexedRows to Python,
     // so return a DataFrame.
     val sc = indexedRowMatrix.rows.sparkContext
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     spark.createDataFrame(indexedRowMatrix.rows)
   }
 
@@ -1188,7 +1188,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     // We use DataFrames for serialization of MatrixEntry entries to
     // Python, so return a DataFrame.
     val sc = coordinateMatrix.entries.sparkContext
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     spark.createDataFrame(coordinateMatrix.entries)
   }
 
@@ -1199,7 +1199,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     // We use DataFrames for serialization of sub-matrix blocks to
     // Python, so return a DataFrame.
     val sc = blockMatrix.blocks.sparkContext
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     spark.createDataFrame(blockMatrix.blocks)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index b186ca3770..e4cc784cfe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -437,7 +437,7 @@ class LogisticRegressionWithLBFGS
     lr.setMaxIter(optimizer.getNumIterations())
     lr.setTol(optimizer.getConvergenceTol())
     // Convert our input into a DataFrame
-    val spark = SparkSession.builder().config(input.context.getConf).getOrCreate()
+    val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
     val df = spark.createDataFrame(input.map(_.asML))
     // Determine if we should cache the DF
     val handlePersistence = input.getStorageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 452802f043..593a86f69a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -193,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         modelType: String)
 
     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -207,7 +207,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
 
     @Since("1.3.0")
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       // Load Parquet data.
       val dataRDD = spark.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
@@ -238,7 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         theta: Array[Array[Double]])
 
     def save(sc: SparkContext, path: String, data: Data): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -251,7 +251,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     }
 
     def load(sc: SparkContext, path: String): NaiveBayesModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       // Load Parquet data.
       val dataRDD = spark.read.parquet(dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 32e323d080..84491181d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
         weights: Vector,
         intercept: Double,
         threshold: Option[Double]): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -73,7 +73,7 @@ private[classification] object GLMClassificationModel {
      */
     def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
       val dataPath = Loader.dataPath(path)
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataRDD = spark.read.parquet(dataPath)
       val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
       assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 510a91b5a7..b3546a1ee3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -144,7 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
 
     def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
           ~ ("rootId" -> model.root.index)))
@@ -165,7 +165,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
     }
 
     def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val rows = spark.read.parquet(Loader.dataPath(path))
       Loader.checkSchema[Data](rows.schema)
       val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 31ad56dba6..c30cc3e239 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -143,7 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
         path: String,
         weights: Array[Double],
         gaussians: Array[MultivariateGaussian]): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render
@@ -159,7 +159,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
 
     def load(sc: SparkContext, path: String): GaussianMixtureModel = {
       val dataPath = Loader.dataPath(path)
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataFrame = spark.read.parquet(dataPath)
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       Loader.checkSchema[Data](dataFrame.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 5f939c1a21..aa78149699 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -123,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] {
     val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
 
     def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
@@ -135,7 +135,7 @@ object KMeansModel extends Loader[KMeansModel] {
 
     def load(sc: SparkContext, path: String): KMeansModel = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 1b66013d54..4f07236225 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -446,7 +446,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
         docConcentration: Vector,
         topicConcentration: Double,
         gammaShape: Double): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val k = topicsMatrix.numCols
       val metadata = compact(render
         (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -470,7 +470,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
         topicConcentration: Double,
         gammaShape: Double): LocalLDAModel = {
       val dataPath = Loader.dataPath(path)
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataFrame = spark.read.parquet(dataPath)
 
       Loader.checkSchema[Data](dataFrame.schema)
@@ -851,7 +851,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
         topicConcentration: Double,
         iterationTimes: Array[Double],
         gammaShape: Double): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val metadata = compact(render
         (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -887,7 +887,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
       val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
       val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
       val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataFrame = spark.read.parquet(dataPath)
       val vertexDataFrame = spark.read.parquet(vertexDataPath)
       val edgeDataFrame = spark.read.parquet(edgeDataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 51077bd630..c760ddd6ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
 
     @Since("1.4.0")
     def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
@@ -82,7 +82,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
 
     @Since("1.4.0")
     def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 13decefcd6..c8c2823bba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
 
     val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
 
     def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -149,7 +149,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
 
     def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 9bd79aa7c6..2f52825c6c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -609,7 +609,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
     case class Data(word: String, vector: Array[Float])
 
     def load(sc: SparkContext, path: String): Word2VecModel = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataFrame = spark.read.parquet(Loader.dataPath(path))
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       Loader.checkSchema[Data](dataFrame.schema)
@@ -620,7 +620,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
     }
 
     def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val vectorSize = model.values.head.length
       val numWords = model.size
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 8c0639baea..0f7fbe9556 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
 
     def save(model: FPGrowthModel[_], path: String): Unit = {
       val sc = model.freqItemsets.sparkContext
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -123,7 +123,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
 
     def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 10bbcd2a3d..c13c794775 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -616,7 +616,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
 
     def save(model: PrefixSpanModel[_], path: String): Unit = {
       val sc = model.freqSequences.sparkContext
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -640,7 +640,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
 
     def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 450025f477..c642573ccb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -354,7 +354,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
      */
     def save(model: MatrixFactorizationModel, path: String): Unit = {
       val sc = model.userFeatures.sparkContext
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       import spark.implicits._
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
@@ -365,7 +365,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
 
     def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
       implicit val formats = DefaultFormats
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val (className, formatVersion, metadata) = loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 215a799b96..1cd6f2a896 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
         boundaries: Array[Double],
         predictions: Array[Double],
         isotonic: Boolean): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
     }
 
     def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataRDD = spark.read.parquet(dataPath(path))
 
       checkSchema[Data](dataRDD.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 3c7bbc5244..cd90e97cc5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
         modelClass: String,
         weights: Vector,
         intercept: Double): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // Create JSON metadata.
       val metadata = compact(render(
@@ -68,7 +68,7 @@ private[regression] object GLMRegressionModel {
      */
     def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
       val dataPath = Loader.dataPath(path)
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataRDD = spark.read.parquet(dataPath)
       val dataArray = dataRDD.select("weights", "intercept").take(1)
       assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 72663188a9..a1562384b0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -233,13 +233,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
       // Create Parquet data.
       val nodes = model.topNode.subtreeIterator.toSeq
       val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _))
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
       // Load Parquet data.
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val dataPath = Loader.dataPath(path)
       val dataRDD = spark.read.parquet(dataPath)
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index c653b988e2..f7d9b22b6f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -413,7 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging {
     case class EnsembleNodeData(treeId: Int, node: NodeData)
 
     def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
 
       // SPARK-6120: We do a hacky check here so users understand why save() is failing
       // when they run the ML guide example.
@@ -471,7 +471,7 @@ private[tree] object TreeEnsembleModel extends Logging {
       sc: SparkContext,
       path: String,
       treeAlgo: String): Array[DecisionTreeModel] = {
-      val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       import spark.implicits._
       val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply)
       val trees = constructTrees(nodes.rdd)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 40d5b4881f..3558290b23 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -23,18 +23,14 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.Row
 
 class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
   with DefaultReadWriteTest {
 
   test("Test Chi-Square selector") {
-    val spark = SparkSession.builder
-      .master("local[2]")
-      .appName("ChiSqSelectorSuite")
-      .getOrCreate()
+    val spark = this.spark
     import spark.implicits._
-
     val data = Seq(
       LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
       LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 621c13a8e5..b73dbd6232 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -27,7 +27,7 @@ class QuantileDiscretizerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("Test observed number of buckets and their sizes match expected values") {
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = this.spark
     import spark.implicits._
 
     val datasetSize = 100000
@@ -53,7 +53,7 @@ class QuantileDiscretizerSuite
   }
 
   test("Test transform method on unseen data") {
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = this.spark
     import spark.implicits._
 
     val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
@@ -82,7 +82,7 @@ class QuantileDiscretizerSuite
   }
 
   test("Verify resulting model has parent") {
-    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+    val spark = this.spark
     import spark.implicits._
 
     val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 59b5edc401..e8ed50acf8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -591,6 +591,7 @@ class ALSCleanerSuite extends SparkFunSuite {
       val spark = SparkSession.builder
         .master("local[2]")
         .appName("ALSCleanerSuite")
+        .sparkContext(sc)
         .getOrCreate()
       import spark.implicits._
       val als = new ALS()
@@ -606,7 +607,7 @@ class ALSCleanerSuite extends SparkFunSuite {
       val pattern = "shuffle_(\\d+)_.+\\.data".r
       val rddIds = resultingFiles.flatMap { f =>
         pattern.findAllIn(f.getName()).matchData.map { _.group(1) } }
-      assert(rddIds.toSet.size === 4)
+      assert(rddIds.size === 4)
     } finally {
       sc.stop()
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 8cbd652bac..d2fa8d0d63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -42,9 +42,10 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val spark = SparkSession.builder
+    val spark = SparkSession.builder()
       .master("local[2]")
       .appName("TreeTests")
+      .sparkContext(data.sparkContext)
       .getOrCreate()
     import spark.implicits._
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 20e22baa35..dc4b72a6fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -694,7 +694,7 @@ object SparkSession {
 
     private[this] var userSuppliedContext: Option[SparkContext] = None
 
-    private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
+    private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
       userSuppliedContext = Option(sparkContext)
       this
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index db32b6b6be..97adffa8ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -1,25 +1,25 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 package org.apache.spark.sql.execution.joins
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.SparkPlan
@@ -44,11 +44,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
    */
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val conf = new SparkConf()
-      .setMaster("local-cluster[2,1,1024]")
-      .setAppName("testing")
-    val sc = new SparkContext(conf)
-    spark = SparkSession.builder.getOrCreate()
+    spark = SparkSession.builder()
+      .master("local-cluster[2,1,1024]")
+      .appName("testing")
+      .getOrCreate()
   }
 
   override def afterAll(): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index a4bbe96cf8..d56bede0cc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource}
 import org.apache.spark.sql.expressions.Window
@@ -282,15 +282,12 @@ object SetWarehouseLocationTest extends Logging {
     val hiveWarehouseLocation = Utils.createTempDir()
     hiveWarehouseLocation.delete()
 
-    val conf = new SparkConf()
-    conf.set("spark.ui.enabled", "false")
     // We will use the value of spark.sql.warehouse.dir override the
     // value of hive.metastore.warehouse.dir.
-    conf.set("spark.sql.warehouse.dir", warehouseLocation.toString)
-    conf.set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)
-
-    val sparkSession = SparkSession.builder
-      .config(conf)
+    val sparkSession = SparkSession.builder()
+      .config("spark.ui.enabled", "false")
+      .config("spark.sql.warehouse.dir", warehouseLocation.toString)
+      .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)
       .enableHiveSupport()
       .getOrCreate()
 
-- 
cgit v1.2.3
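
For readers who want to see the adopted pattern outside the patch context, below is a minimal, self-contained sketch. It is not part of the commit: the object name and the `local[2]` master are invented for illustration, and because `sparkContext(..)` remains `private[spark]` after this change, the call only compiles from code under an `org.apache.spark` package, like the examples and MLlib loaders touched above.

```
// Hypothetical illustration (not part of the patch): handing a pre-existing
// SparkContext to SparkSession.builder instead of copying its SparkConf.
package org.apache.spark.examples

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

object SparkContextReuseSketch {
  def main(args: Array[String]): Unit = {
    // An already-running context, e.g. created by a legacy RDD-based job.
    val sc = new SparkContext("local[2]", "SparkContextReuseSketch")

    // Old pattern removed by this patch: rebuilds the session from a
    // snapshot of the conf rather than from the live context.
    //   val spark = SparkSession.builder().config(sc.getConf).getOrCreate()

    // New pattern: the builder records the supplied context and wraps it.
    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()

    // The session wraps the very same SparkContext instance (assuming no
    // default session bound to a different context already exists).
    assert(spark.sparkContext eq sc)

    spark.stop()
  }
}
```

One motivation visible throughout the diff: `config(sc.getConf)` only copies configuration, so the builder may still construct a second context-like view of the world, whereas `sparkContext(sc)` makes the reuse of the caller's context explicit and exact.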