From 3e7e05f5ee763925ed60410d7de04cf36b723de1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 7 Dec 2015 16:37:09 -0800 Subject: [SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib Switched from using SQLContext constructor to using getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib. There were no uses of the constructor in spark.ml. CC: mengxr yhuai Author: Joseph K. Bradley Closes #10161 from jkbradley/mllib-sqlcontext-fix. --- .../scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- .../scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++---- .../spark/mllib/classification/impl/GLMClassificationModel.scala | 4 ++-- .../org/apache/spark/mllib/clustering/GaussianMixtureModel.scala | 4 ++-- .../scala/org/apache/spark/mllib/clustering/KMeansModel.scala | 4 ++-- .../apache/spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- .../spark/mllib/recommendation/MatrixFactorizationModel.scala | 4 ++-- .../org/apache/spark/mllib/regression/IsotonicRegression.scala | 4 ++-- .../apache/spark/mllib/regression/impl/GLMRegressionModel.scala | 4 ++-- .../org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4 ++-- .../org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 4 ++-- 13 files changed, 29 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 54b03a9f90..2aa6aec0b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. - val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext) + val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) sqlContext.createDataFrame(indexedRowMatrix.rows) } @@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. - val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) + val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } @@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. - val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) sqlContext.createDataFrame(blockMatrix.blocks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a956084ae0..aef9ef2cb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index fe09f6b75d..2910c027ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 2115f7d99c..74d13e4f77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a741584982..91fa9b0d35 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 7cd9b08fa8..bb18045059 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( @@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index d4d022afde..eaa99cfe82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b693f3c8e4..23b1514e30 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -587,7 +587,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -599,7 +599,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val vectorSize = model.values.head.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 46562eb2ad..0dc40483dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ec78ea2453..f235089873 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 317d3a5702..02af281fb7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 54c136aecf..89c470d573 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -201,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -242,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 90e032e3d9..3f427f0be3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -408,7 +408,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -468,7 +468,7 @@ private[tree] object TreeEnsembleModel extends Logging { path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) -- cgit v1.2.3