author    Dongjoon Hyun <dongjoon@apache.org>    2016-05-27 11:09:15 -0700
committer Andrew Or <andrew@databricks.com>      2016-05-27 11:09:15 -0700
commit    d24e251572d39a453293cabfe14e4aed25a55208
tree      448a5e9a9c4ecd553b180c34f82529cacb09592b
parent    c17272902c95290beca274ee6316a8a98fd7a725
[SPARK-15603][MLLIB] Replace SQLContext with SparkSession in ML/MLLib
## What changes were proposed in this pull request?

This PR replaces all deprecated `SQLContext` occurrences with `SparkSession` in the `ML/MLLib` module, except for the following two classes, which use `SQLContext` in their function signatures:

- ReadWrite.scala
- TreeModels.scala

## How was this patch tested?

Pass the existing Jenkins tests.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13352 from dongjoon-hyun/SPARK-15603.
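The replacement applied throughout the diff is mechanical: each `SQLContext.getOrCreate(sc)` (often followed by `import sqlContext.implicits._` and `.toDF()`) becomes a `SparkSession` built from the existing `SparkContext`'s configuration, with `createDataFrame` called directly. A minimal before/after sketch of that pattern follows; it is not part of the commit, and `Data`, `MigrationSketch`, `sc`, and `path` are placeholder names used only for illustration.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Hypothetical record type used only for this sketch.
case class Data(weight: Double, label: String)

object MigrationSketch {
  // Before (deprecated pattern removed by this PR):
  //   val sqlContext = SQLContext.getOrCreate(sc)
  //   import sqlContext.implicits._
  //   sc.parallelize(data, 1).toDF().write.parquet(path)

  // After: reuse (or create) a SparkSession from the SparkContext's conf
  // and build the DataFrame directly, without importing implicits.
  def save(sc: SparkContext, path: String, data: Seq[Data]): Unit = {
    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
    spark.createDataFrame(data).repartition(1).write.parquet(path)
  }

  def load(sc: SparkContext, path: String): Array[Data] = {
    val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
    import spark.implicits._  // case-class encoder, only needed for the typed Dataset
    spark.read.parquet(path).as[Data].collect()
  }
}
```

In the save path, `repartition(1)` stands in for the old `sc.parallelize(..., 1)` so the written output keeps a single partition, mirroring what the touched `save()` helpers in the diff do.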
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala                     |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala                         | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala                               | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala                   | 23
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala           | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala                   | 24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala  | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala             | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala             | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala                      | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala                         | 36
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala         | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala                       | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala                            | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala                                | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala                              | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala     | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala               | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala          | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala                | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala               | 17
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala               | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala                  |  6
23 files changed, 160 insertions(+), 195 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88b6b27e62..773e50e245 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.clustering
import breeze.linalg.{DenseVector => BDV}
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.impl.Utils.EPSILON
@@ -33,7 +32,7 @@ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -134,9 +133,7 @@ class GaussianMixtureModel private[ml] (
val modelGaussians = gaussians.map { gaussian =>
(OldVectors.fromML(gaussian.mean), OldMatrices.fromML(gaussian.cov))
}
- val sc = SparkContext.getOrCreate()
- val sqlContext = SQLContext.getOrCreate(sc)
- sqlContext.createDataFrame(modelGaussians).toDF("mean", "cov")
+ SparkSession.builder().getOrCreate().createDataFrame(modelGaussians).toDF("mean", "cov")
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 2d4cac6dc4..bd8f9494fb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.types.StructType
/**
@@ -74,15 +73,14 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- val sc = SparkContext.getOrCreate()
- val sqlContext = SQLContext.getOrCreate(sc)
- val dummyRDD = sc.parallelize(Seq(Row.empty))
- val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
+ val spark = SparkSession.builder().getOrCreate()
+ val dummyRDD = spark.sparkContext.parallelize(Seq(Row.empty))
+ val dummyDF = spark.createDataFrame(dummyRDD, schema)
val tableName = Identifiable.randomUID(uid)
val realStatement = $(statement).replace(tableIdentifier, tableName)
dummyDF.createOrReplaceTempView(tableName)
- val outputSchema = sqlContext.sql(realStatement).schema
- sqlContext.dropTempTable(tableName)
+ val outputSchema = spark.sql(realStatement).schema
+ spark.catalog.dropTempView(tableName)
outputSchema
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 1469bfd5e8..1b929cdfff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -183,11 +183,9 @@ class Word2VecModel private[ml] (
* and the vector the DenseVector that it is mapped to.
*/
@transient lazy val getVectors: DataFrame = {
- val sc = SparkContext.getOrCreate()
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().getOrCreate()
val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
- sc.parallelize(wordVec.toSeq).toDF("word", "vector")
+ spark.createDataFrame(wordVec.toSeq).toDF("word", "vector")
}
/**
@@ -205,10 +203,8 @@ class Word2VecModel private[ml] (
* synonyms and the given word vector.
*/
def findSynonyms(word: Vector, num: Int): DataFrame = {
- val sc = SparkContext.getOrCreate()
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
- sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ val spark = SparkSession.builder().getOrCreate()
+ spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
}
/** @group setParam */
@@ -230,7 +226,7 @@ class Word2VecModel private[ml] (
val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
val d = $(vectorSize)
val word2Vec = udf { sentence: Seq[String] =>
- if (sentence.size == 0) {
+ if (sentence.isEmpty) {
Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
} else {
val sum = Vectors.zeros(d)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 667290ece3..6e0ed374c7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -30,8 +30,7 @@ import net.razorvine.pickle._
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
-import org.apache.spark.ml.linalg.{DenseMatrix => NewDenseMatrix, DenseVector => NewDenseVector, SparseMatrix => NewSparseMatrix, SparseVector => NewSparseVector, Vector => NewVector, Vectors => NewVectors}
+import org.apache.spark.ml.linalg.{DenseMatrix => NewDenseMatrix, DenseVector => NewDenseVector, SparseMatrix => NewSparseMatrix, SparseVector => NewSparseVector, Vectors => NewVectors}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.evaluation.RankingMetrics
@@ -43,8 +42,7 @@ import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.stat.{
- KernelDensity, MultivariateStatisticalSummary, Statistics}
+import org.apache.spark.mllib.stat.{KernelDensity, MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult}
@@ -56,7 +54,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
RandomForestModel}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -1178,8 +1176,9 @@ private[python] class PythonMLLibAPI extends Serializable {
def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
// We use DataFrames for serialization of IndexedRows to Python,
// so return a DataFrame.
- val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext)
- sqlContext.createDataFrame(indexedRowMatrix.rows)
+ val sc = indexedRowMatrix.rows.sparkContext
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ spark.createDataFrame(indexedRowMatrix.rows)
}
/**
@@ -1188,8 +1187,9 @@ private[python] class PythonMLLibAPI extends Serializable {
def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
// We use DataFrames for serialization of MatrixEntry entries to
// Python, so return a DataFrame.
- val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext)
- sqlContext.createDataFrame(coordinateMatrix.entries)
+ val sc = coordinateMatrix.entries.sparkContext
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ spark.createDataFrame(coordinateMatrix.entries)
}
/**
@@ -1198,8 +1198,9 @@ private[python] class PythonMLLibAPI extends Serializable {
def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
// We use DataFrames for serialization of sub-matrix blocks to
// Python, so return a DataFrame.
- val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext)
- sqlContext.createDataFrame(blockMatrix.blocks)
+ val sc = blockMatrix.blocks.sparkContext
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ spark.createDataFrame(blockMatrix.blocks)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 4bba2ea057..b186ca3770 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
/**
@@ -422,7 +422,7 @@ class LogisticRegressionWithLBFGS
LogisticRegressionModel = {
// ml's Logistic regression only supports binary classification currently.
if (numOfLinearPredictor == 1) {
- def runWithMlLogisitcRegression(elasticNetParam: Double) = {
+ def runWithMlLogisticRegression(elasticNetParam: Double) = {
// Prepare the ml LogisticRegression based on our settings
val lr = new org.apache.spark.ml.classification.LogisticRegression()
lr.setRegParam(optimizer.getRegParam())
@@ -437,20 +437,19 @@ class LogisticRegressionWithLBFGS
lr.setMaxIter(optimizer.getNumIterations())
lr.setTol(optimizer.getConvergenceTol())
// Convert our input into a DataFrame
- val sqlContext = SQLContext.getOrCreate(input.context)
- import sqlContext.implicits._
- val df = input.map(_.asML).toDF()
+ val spark = SparkSession.builder().config(input.context.getConf).getOrCreate()
+ val df = spark.createDataFrame(input.map(_.asML))
// Determine if we should cache the DF
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
// Train our model
- val mlLogisticRegresionModel = lr.train(df, handlePersistence)
+ val mlLogisticRegressionModel = lr.train(df, handlePersistence)
// convert the model
- val weights = Vectors.dense(mlLogisticRegresionModel.coefficients.toArray)
- createModel(weights, mlLogisticRegresionModel.intercept)
+ val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
+ createModel(weights, mlLogisticRegressionModel.intercept)
}
optimizer.getUpdater() match {
- case x: SquaredL2Updater => runWithMlLogisitcRegression(0.0)
- case x: L1Updater => runWithMlLogisitcRegression(1.0)
+ case x: SquaredL2Updater => runWithMlLogisticRegression(0.0)
+ case x: L1Updater => runWithMlLogisticRegression(1.0)
case _ => super.run(input, initialWeights)
}
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index eb3ee41f7c..452802f043 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -31,7 +31,7 @@ import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVect
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.SparkSession
/**
* Model for Naive Bayes Classifiers.
@@ -193,8 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
modelType: String)
def save(sc: SparkContext, path: String, data: Data): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -203,15 +202,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
// Create Parquet data.
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
- dataRDD.write.parquet(dataPath(path))
+ spark.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath(path))
}
@Since("1.3.0")
def load(sc: SparkContext, path: String): NaiveBayesModel = {
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Load Parquet data.
- val dataRDD = sqlContext.read.parquet(dataPath(path))
+ val dataRDD = spark.read.parquet(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
@@ -240,8 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
theta: Array[Array[Double]])
def save(sc: SparkContext, path: String, data: Data): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -250,14 +247,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
// Create Parquet data.
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
- dataRDD.write.parquet(dataPath(path))
+ spark.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath(path))
}
def load(sc: SparkContext, path: String): NaiveBayesModel = {
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Load Parquet data.
- val dataRDD = sqlContext.read.parquet(dataPath(path))
+ val dataRDD = spark.read.parquet(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
@@ -327,7 +323,7 @@ class NaiveBayes private (
@Since("0.9.0")
def setLambda(lambda: Double): NaiveBayes = {
require(lambda >= 0,
- s"Smoothing parameter must be nonnegative but got ${lambda}")
+ s"Smoothing parameter must be nonnegative but got $lambda")
this.lambda = lambda
this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 4308ae04ee..32e323d080 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -23,7 +23,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
/**
* Helper class for import/export of GLM classification models.
@@ -51,8 +51,7 @@ private[classification] object GLMClassificationModel {
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -62,7 +61,7 @@ private[classification] object GLMClassificationModel {
// Create Parquet data.
val data = Data(weights, intercept, threshold)
- sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(Seq(data)).repartition(1).write.parquet(Loader.dataPath(path))
}
/**
@@ -73,13 +72,13 @@ private[classification] object GLMClassificationModel {
* @param modelClass String name for model class (used for error messages)
*/
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
- val datapath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataRDD = sqlContext.read.parquet(datapath)
+ val dataPath = Loader.dataPath(path)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataRDD = spark.read.parquet(dataPath)
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
- assert(dataArray.length == 1, s"Unable to load $modelClass data from: $datapath")
+ assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
val data = dataArray(0)
- assert(data.size == 3, s"Unable to load $modelClass data from: $datapath")
+ assert(data.size == 3, s"Unable to load $modelClass data from: $dataPath")
val (weights, intercept) = data match {
case Row(weights: Vector, intercept: Double, _) =>
(weights, intercept)
@@ -92,5 +91,4 @@ private[classification] object GLMClassificationModel {
Data(weights, intercept, threshold)
}
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index c3b5b8b790..510a91b5a7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
/**
* Clustering model produced by [[BisectingKMeans]].
@@ -144,8 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
~ ("rootId" -> model.root.index)))
@@ -154,8 +153,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
val data = getNodes(model.root).map(node => Data(node.index, node.size,
node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
node.children.map(_.index)))
- val dataRDD = sc.parallelize(data).toDF()
- dataRDD.write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
}
private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
@@ -167,8 +165,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
}
def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
- val sqlContext = SQLContext.getOrCreate(sc)
- val rows = sqlContext.read.parquet(Loader.dataPath(path))
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val rows = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[Data](rows.schema)
val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index f87613cc72..4b068164b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
/**
* Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
@@ -143,9 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
path: String,
weights: Array[Double],
gaussians: Array[MultivariateGaussian]): Unit = {
-
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Create JSON metadata.
val metadata = compact(render
@@ -156,13 +154,13 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
val dataArray = Array.tabulate(weights.length) { i =>
Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
}
- sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): GaussianMixtureModel = {
val dataPath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataFrame = sqlContext.read.parquet(dataPath)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataFrame = spark.read.parquet(dataPath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
@@ -172,7 +170,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
(weight, new MultivariateGaussian(mu, sigma))
}.unzip
- new GaussianMixtureModel(weights.toArray, gaussians.toArray)
+ new GaussianMixtureModel(weights, gaussians)
}
}
@@ -189,7 +187,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
s"GaussianMixtureModel requires weights of length $k " +
s"got weights of length ${model.weights.length}")
require(model.gaussians.length == k,
- s"GaussianMixtureModel requires gaussians of length $k" +
+ s"GaussianMixtureModel requires gaussians of length $k " +
s"got gaussians of length ${model.gaussians.length}")
model
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 439e4f8672..5f939c1a21 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
@@ -123,25 +123,24 @@ object KMeansModel extends Loader[KMeansModel] {
val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
Cluster(id, point)
- }.toDF()
- dataRDD.write.parquet(Loader.dataPath(path))
+ }
+ spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): KMeansModel = {
implicit val formats = DefaultFormats
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val k = (metadata \ "k").extract[Int]
- val centroids = sqlContext.read.parquet(Loader.dataPath(path))
+ val centroids = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[Cluster](centroids.schema)
val localCentroids = centroids.rdd.map(Cluster.apply).collect()
assert(k == localCentroids.length)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 4913c0287a..0a515f893d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -31,7 +31,7 @@ import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.util.BoundedPriorityQueue
/**
@@ -446,9 +446,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
docConcentration: Vector,
topicConcentration: Double,
gammaShape: Double): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
-
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val k = topicsMatrix.numCols
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -461,8 +459,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
val topics = Range(0, k).map { topicInd =>
Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
- }.toSeq
- sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
+ }
+ spark.createDataFrame(topics).repartition(1).write.parquet(Loader.dataPath(path))
}
def load(
@@ -472,8 +470,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
topicConcentration: Double,
gammaShape: Double): LocalLDAModel = {
val dataPath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataFrame = sqlContext.read.parquet(dataPath)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataFrame = spark.read.parquet(dataPath)
Loader.checkSchema[Data](dataFrame.schema)
val topics = dataFrame.collect()
@@ -853,8 +851,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
topicConcentration: Double,
iterationTimes: Array[Double],
gammaShape: Double): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -866,18 +863,17 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
- sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
- .write.parquet(newPath)
+ spark.createDataFrame(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).write.parquet(newPath)
val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
- graph.vertices.map { case (ind, vertex) =>
+ spark.createDataFrame(graph.vertices.map { case (ind, vertex) =>
VertexData(ind, Vectors.fromBreeze(vertex))
- }.toDF().write.parquet(verticesPath)
+ }).write.parquet(verticesPath)
val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
- graph.edges.map { case Edge(srcId, dstId, prop) =>
+ spark.createDataFrame(graph.edges.map { case Edge(srcId, dstId, prop) =>
EdgeData(srcId, dstId, prop)
- }.toDF().write.parquet(edgesPath)
+ }).write.parquet(edgesPath)
}
def load(
@@ -891,10 +887,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataFrame = sqlContext.read.parquet(dataPath)
- val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
- val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataFrame = spark.read.parquet(dataPath)
+ val vertexDataFrame = spark.read.parquet(vertexDataPath)
+ val edgeDataFrame = spark.read.parquet(edgeDataPath)
Loader.checkSchema[Data](dataFrame.schema)
Loader.checkSchema[VertexData](vertexDataFrame.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 2e257ff9b7..51077bd630 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -70,28 +70,26 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
@Since("1.4.0")
def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
- val dataRDD = model.assignments.toDF()
- dataRDD.write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(model.assignments).write.parquet(Loader.dataPath(path))
}
@Since("1.4.0")
def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
implicit val formats = DefaultFormats
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val k = (metadata \ "k").extract[Int]
- val assignments = sqlContext.read.parquet(Loader.dataPath(path))
+ val assignments = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
val assignmentsRDD = assignments.rdd.map {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 4f0e13feae..13decefcd6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
@@ -134,8 +134,8 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
@@ -144,18 +144,17 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
Data(model.selectedFeatures(i))
}
- sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
-
+ spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
implicit val formats = DefaultFormats
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
- val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
+ val dataFrame = spark.read.parquet(Loader.dataPath(path))
val dataArray = dataFrame.select("feature")
// Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 7e6c367970..9bd79aa7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -609,9 +609,8 @@ object Word2VecModel extends Loader[Word2VecModel] {
case class Data(word: String, vector: Array[Float])
def load(sc: SparkContext, path: String): Word2VecModel = {
- val dataPath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataFrame = sqlContext.read.parquet(dataPath)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataFrame = spark.read.parquet(Loader.dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
@@ -621,9 +620,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
}
def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
-
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val vectorSize = model.values.head.length
val numWords = model.size
@@ -641,7 +638,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
val approxSize = 4L * numWords * vectorSize
val nPartitions = ((approxSize / partitionSize) + 1).toInt
val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
- sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 28e4966f91..8c0639baea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.fpm.FPGrowth._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
def save(model: FPGrowthModel[_], path: String): Unit = {
val sc = model.freqItemsets.sparkContext
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -118,18 +118,18 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
val rowDataRDD = model.freqItemsets.map { x =>
Row(x.items.toSeq, x.freq)
}
- sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
implicit val formats = DefaultFormats
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
- val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
+ val freqItemsets = spark.read.parquet(Loader.dataPath(path))
val sample = freqItemsets.select("items").head().get(0)
loadImpl(freqItemsets, sample)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 4344ab1bad..10bbcd2a3d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -36,7 +36,7 @@ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -616,7 +616,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
def save(model: PrefixSpanModel[_], path: String): Unit = {
val sc = model.freqSequences.sparkContext
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -635,18 +635,18 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
val rowDataRDD = model.freqSequences.map { x =>
Row(x.sequence, x.freq)
}
- sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
implicit val formats = DefaultFormats
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
- val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
+ val freqSequences = spark.read.parquet(Loader.dataPath(path))
val sample = freqSequences.select("sequence").head().get(0)
loadImpl(freqSequences, sample)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 6f780b0da7..450025f477 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -37,7 +37,7 @@ import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.storage.StorageLevel
/**
@@ -354,8 +354,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
*/
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ import spark.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
@@ -365,16 +365,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
implicit val formats = DefaultFormats
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val (className, formatVersion, metadata) = loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val rank = (metadata \ "rank").extract[Int]
- val userFeatures = sqlContext.read.parquet(userPath(path)).rdd.map {
+ val userFeatures = spark.read.parquet(userPath(path)).rdd.map {
case Row(id: Int, features: Seq[_]) =>
(id, features.asInstanceOf[Seq[Double]].toArray)
}
- val productFeatures = sqlContext.read.parquet(productPath(path)).rdd.map {
+ val productFeatures = spark.read.parquet(productPath(path)).rdd.map {
case Row(id: Int, features: Seq[_]) =>
(id, features.asInstanceOf[Seq[Double]].toArray)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index abdd798197..215a799b96 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -34,7 +34,7 @@ import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
/**
* Regression model for isotonic regression.
@@ -185,21 +185,21 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
boundaries: Array[Double],
predictions: Array[Double],
isotonic: Boolean): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
- sqlContext.createDataFrame(
+ spark.createDataFrame(
boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
).write.parquet(dataPath(path))
}
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataRDD = sqlContext.read.parquet(dataPath(path))
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataRDD = spark.read.parquet(dataPath(path))
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("boundary", "prediction").collect()
@@ -221,7 +221,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
new IsotonicRegressionModel(boundaries, predictions, isotonic)
case _ => throw new Exception(
- s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
+ s"IsotonicRegressionModel.load did not recognize model with (className, format version): " +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)"
)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 7696fdf2dc..3c7bbc5244 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -23,7 +23,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
/**
* Helper methods for import/export of GLM regression models.
@@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
modelClass: String,
weights: Vector,
intercept: Double): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -57,7 +57,7 @@ private[regression] object GLMRegressionModel {
// Create Parquet data.
val data = Data(weights, intercept)
- sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(Loader.dataPath(path))
+ spark.createDataFrame(Seq(data)).repartition(1).write.parquet(Loader.dataPath(path))
}
/**
@@ -67,17 +67,17 @@ private[regression] object GLMRegressionModel {
* The length of the weights vector should equal numFeatures.
*/
def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
- val datapath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
- val dataRDD = sqlContext.read.parquet(datapath)
+ val dataPath = Loader.dataPath(path)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataRDD = spark.read.parquet(dataPath)
val dataArray = dataRDD.select("weights", "intercept").take(1)
- assert(dataArray.length == 1, s"Unable to load $modelClass data from: $datapath")
+ assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
val data = dataArray(0)
- assert(data.size == 2, s"Unable to load $modelClass data from: $datapath")
+ assert(data.size == 2, s"Unable to load $modelClass data from: $dataPath")
data match {
case Row(weights: Vector, intercept: Double) =>
assert(weights.size == numFeatures, s"Expected $numFeatures features, but" +
- s" found ${weights.size} features when loading $modelClass weights from $datapath")
+ s" found ${weights.size} features when loading $modelClass weights from $dataPath")
Data(weights, intercept)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index c13b9a66c4..72663188a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.util.Utils
/**
@@ -202,9 +202,6 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
}
def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
-
// SPARK-6120: We do a hacky check here so users understand why save() is failing
// when they run the ML guide example.
// TODO: Fix this issue for real.
@@ -235,17 +232,16 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
// Create Parquet data.
val nodes = model.topNode.subtreeIterator.toSeq
- val dataRDD: DataFrame = sc.parallelize(nodes)
- .map(NodeData.apply(0, _))
- .toDF()
- dataRDD.write.parquet(Loader.dataPath(path))
+ val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _))
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
- val datapath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
// Load Parquet data.
- val dataRDD = sqlContext.read.parquet(datapath)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val dataPath = Loader.dataPath(path)
+ val dataRDD = spark.read.parquet(dataPath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[NodeData](dataRDD.schema)
val nodes = dataRDD.rdd.map(NodeData.apply)
@@ -254,7 +250,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
assert(trees.length == 1,
"Decision tree should contain exactly one tree but got ${trees.size} trees.")
val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
- assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." +
+ assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $dataPath." +
s" Expected $numNodes nodes but found ${model.numNodes}")
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index cbf49b6d58..c653b988e2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -36,7 +36,7 @@ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils
/**
@@ -413,8 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging {
case class EnsembleNodeData(treeId: Int, node: NodeData)
def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
- val sqlContext = SQLContext.getOrCreate(sc)
- import sqlContext.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
// SPARK-6120: We do a hacky check here so users understand why save() is failing
// when they run the ML guide example.
@@ -450,8 +449,8 @@ private[tree] object TreeEnsembleModel extends Logging {
// Create Parquet data.
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
- }.toDF()
- dataRDD.write.parquet(Loader.dataPath(path))
+ }
+ spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
/**
@@ -472,10 +471,10 @@ private[tree] object TreeEnsembleModel extends Logging {
sc: SparkContext,
path: String,
treeAlgo: String): Array[DecisionTreeModel] = {
- val datapath = Loader.dataPath(path)
- val sqlContext = SQLContext.getOrCreate(sc)
- val nodes = sqlContext.read.parquet(datapath).rdd.map(NodeData.apply)
- val trees = constructTrees(nodes)
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ import spark.implicits._
+ val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply)
+ val trees = constructTrees(nodes.rdd)
trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 8895d630a0..621c13a8e5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -20,15 +20,15 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("Test observed number of buckets and their sizes match expected values") {
- val sqlCtx = SQLContext.getOrCreate(sc)
- import sqlCtx.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ import spark.implicits._
val datasetSize = 100000
val numBuckets = 5
@@ -53,8 +53,8 @@ class QuantileDiscretizerSuite
}
test("Test transform method on unseen data") {
- val sqlCtx = SQLContext.getOrCreate(sc)
- import sqlCtx.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ import spark.implicits._
val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
@@ -82,8 +82,8 @@ class QuantileDiscretizerSuite
}
test("Verify resulting model has parent") {
- val sqlCtx = SQLContext.getOrCreate(sc)
- import sqlCtx.implicits._
+ val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ import spark.implicits._
val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index ba8d36f45f..db56aff631 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -21,9 +21,9 @@ import java.io.File
import org.scalatest.Suite
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext
import org.apache.spark.ml.util.TempDirectory
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils
trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
@@ -46,7 +46,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
override def afterAll() {
try {
Utils.deleteRecursively(new File(checkpointDir))
- SQLContext.clearActive()
+ SparkSession.clearActiveSession()
if (spark != null) {
spark.stop()
}