aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala7
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala1
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala41
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala13
28 files changed, 79 insertions, 90 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index c50f25d951..a68fd0285f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -29,13 +29,10 @@ object BroadcastTest {
val blockSize = if (args.length > 2) args(2) else "4096"
- val sparkConf = new SparkConf()
- .set("spark.broadcast.blockSize", blockSize)
-
val spark = SparkSession
- .builder
- .config(sparkConf)
+ .builder()
.appName("Broadcast Test")
+ .config("spark.broadcast.blockSize", blockSize)
.getOrCreate()
val sc = spark.sparkContext
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 7651aade49..3fbf8e0333 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -191,6 +191,7 @@ object LDAExample {
val spark = SparkSession
.builder
+ .sparkContext(sc)
.getOrCreate()
import spark.implicits._
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index d3bb7e4398..2d7a01a95d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -22,7 +22,6 @@ import java.io.File
import com.google.common.io.{ByteStreams, Files}
-import org.apache.spark.SparkConf
import org.apache.spark.sql._
object HiveFromSpark {
@@ -35,8 +34,6 @@ object HiveFromSpark {
ByteStreams.copy(kv1Stream, Files.newOutputStreamSupplier(kv1File))
def main(args: Array[String]) {
- val sparkConf = new SparkConf().setAppName("HiveFromSpark")
-
// When working with Hive, one must instantiate `SparkSession` with Hive support, including
// connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined
// functions. Users who do not have an existing Hive deployment can still enable Hive support.
@@ -45,7 +42,7 @@ object HiveFromSpark {
// which defaults to the directory `spark-warehouse` in the current directory that the spark
// application is started.
val spark = SparkSession.builder
- .config(sparkConf)
+ .appName("HiveFromSpark")
.enableHiveSupport()
.getOrCreate()
val sc = spark.sparkContext
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6e0ed374c7..e43469bf1c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1177,7 +1177,7 @@ private[python] class PythonMLLibAPI extends Serializable {
// We use DataFrames for serialization of IndexedRows to Python,
// so return a DataFrame.
val sc = indexedRowMatrix.rows.sparkContext
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
spark.createDataFrame(indexedRowMatrix.rows)
}
@@ -1188,7 +1188,7 @@ private[python] class PythonMLLibAPI extends Serializable {
// We use DataFrames for serialization of MatrixEntry entries to
// Python, so return a DataFrame.
val sc = coordinateMatrix.entries.sparkContext
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
spark.createDataFrame(coordinateMatrix.entries)
}
@@ -1199,7 +1199,7 @@ private[python] class PythonMLLibAPI extends Serializable {
// We use DataFrames for serialization of sub-matrix blocks to
// Python, so return a DataFrame.
val sc = blockMatrix.blocks.sparkContext
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
spark.createDataFrame(blockMatrix.blocks)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index b186ca3770..e4cc784cfe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -437,7 +437,7 @@ class LogisticRegressionWithLBFGS
lr.setMaxIter(optimizer.getNumIterations())
lr.setTol(optimizer.getConvergenceTol())
// Convert our input into a DataFrame
- val spark = SparkSession.builder().config(input.context.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
val df = spark.createDataFrame(input.map(_.asML))
// Determine if we should cache the DF
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 452802f043..593a86f69a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -193,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
modelType: String)
def save(sc: SparkContext, path: String, data: Data): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -207,7 +207,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
@Since("1.3.0")
def load(sc: SparkContext, path: String): NaiveBayesModel = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Load Parquet data.
val dataRDD = spark.read.parquet(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
@@ -238,7 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
theta: Array[Array[Double]])
def save(sc: SparkContext, path: String, data: Data): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -251,7 +251,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
}
def load(sc: SparkContext, path: String): NaiveBayesModel = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Load Parquet data.
val dataRDD = spark.read.parquet(dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 32e323d080..84491181d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
weights: Vector,
intercept: Double,
threshold: Option[Double]): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -73,7 +73,7 @@ private[classification] object GLMClassificationModel {
*/
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
val dataPath = Loader.dataPath(path)
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataRDD = spark.read.parquet(dataPath)
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 510a91b5a7..b3546a1ee3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -144,7 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
~ ("rootId" -> model.root.index)))
@@ -165,7 +165,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
}
def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val rows = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[Data](rows.schema)
val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 31ad56dba6..c30cc3e239 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -143,7 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
path: String,
weights: Array[Double],
gaussians: Array[MultivariateGaussian]): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Create JSON metadata.
val metadata = compact(render
@@ -159,7 +159,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
def load(sc: SparkContext, path: String): GaussianMixtureModel = {
val dataPath = Loader.dataPath(path)
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataFrame = spark.read.parquet(dataPath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 5f939c1a21..aa78149699 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -123,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] {
val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
@@ -135,7 +135,7 @@ object KMeansModel extends Loader[KMeansModel] {
def load(sc: SparkContext, path: String): KMeansModel = {
implicit val formats = DefaultFormats
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 1b66013d54..4f07236225 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -446,7 +446,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
docConcentration: Vector,
topicConcentration: Double,
gammaShape: Double): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val k = topicsMatrix.numCols
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -470,7 +470,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
topicConcentration: Double,
gammaShape: Double): LocalLDAModel = {
val dataPath = Loader.dataPath(path)
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataFrame = spark.read.parquet(dataPath)
Loader.checkSchema[Data](dataFrame.schema)
@@ -851,7 +851,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
topicConcentration: Double,
iterationTimes: Array[Double],
gammaShape: Double): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -887,7 +887,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataFrame = spark.read.parquet(dataPath)
val vertexDataFrame = spark.read.parquet(vertexDataPath)
val edgeDataFrame = spark.read.parquet(edgeDataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 51077bd630..c760ddd6ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
@Since("1.4.0")
def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
@@ -82,7 +82,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
@Since("1.4.0")
def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
implicit val formats = DefaultFormats
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 13decefcd6..c8c2823bba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -149,7 +149,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
implicit val formats = DefaultFormats
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 9bd79aa7c6..2f52825c6c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -609,7 +609,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
case class Data(word: String, vector: Array[Float])
def load(sc: SparkContext, path: String): Word2VecModel = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataFrame = spark.read.parquet(Loader.dataPath(path))
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
@@ -620,7 +620,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
}
def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val vectorSize = model.values.head.length
val numWords = model.size
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 8c0639baea..0f7fbe9556 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
def save(model: FPGrowthModel[_], path: String): Unit = {
val sc = model.freqItemsets.sparkContext
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -123,7 +123,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
implicit val formats = DefaultFormats
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 10bbcd2a3d..c13c794775 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -616,7 +616,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
def save(model: PrefixSpanModel[_], path: String): Unit = {
val sc = model.freqSequences.sparkContext
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
@@ -640,7 +640,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
implicit val formats = DefaultFormats
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 450025f477..c642573ccb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -354,7 +354,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
*/
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
import spark.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
@@ -365,7 +365,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
implicit val formats = DefaultFormats
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val (className, formatVersion, metadata) = loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 215a799b96..1cd6f2a896 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
boundaries: Array[Double],
predictions: Array[Double],
isotonic: Boolean): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
@@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
}
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataRDD = spark.read.parquet(dataPath(path))
checkSchema[Data](dataRDD.schema)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 3c7bbc5244..cd90e97cc5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
modelClass: String,
weights: Vector,
intercept: Double): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// Create JSON metadata.
val metadata = compact(render(
@@ -68,7 +68,7 @@ private[regression] object GLMRegressionModel {
*/
def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
val dataPath = Loader.dataPath(path)
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataRDD = spark.read.parquet(dataPath)
val dataArray = dataRDD.select("weights", "intercept").take(1)
assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 72663188a9..a1562384b0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -233,13 +233,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
// Create Parquet data.
val nodes = model.topNode.subtreeIterator.toSeq
val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _))
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
// Load Parquet data.
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val dataPath = Loader.dataPath(path)
val dataRDD = spark.read.parquet(dataPath)
// Check schema explicitly since erasure makes it hard to use match-case for checking.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index c653b988e2..f7d9b22b6f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -413,7 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging {
case class EnsembleNodeData(treeId: Int, node: NodeData)
def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
// SPARK-6120: We do a hacky check here so users understand why save() is failing
// when they run the ML guide example.
@@ -471,7 +471,7 @@ private[tree] object TreeEnsembleModel extends Logging {
sc: SparkContext,
path: String,
treeAlgo: String): Array[DecisionTreeModel] = {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
import spark.implicits._
val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply)
val trees = constructTrees(nodes.rdd)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 40d5b4881f..3558290b23 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -23,18 +23,14 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.Row
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
test("Test Chi-Square selector") {
- val spark = SparkSession.builder
- .master("local[2]")
- .appName("ChiSqSelectorSuite")
- .getOrCreate()
+ val spark = this.spark
import spark.implicits._
-
val data = Seq(
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 621c13a8e5..b73dbd6232 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -27,7 +27,7 @@ class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("Test observed number of buckets and their sizes match expected values") {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = this.spark
import spark.implicits._
val datasetSize = 100000
@@ -53,7 +53,7 @@ class QuantileDiscretizerSuite
}
test("Test transform method on unseen data") {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = this.spark
import spark.implicits._
val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
@@ -82,7 +82,7 @@ class QuantileDiscretizerSuite
}
test("Verify resulting model has parent") {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = this.spark
import spark.implicits._
val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 59b5edc401..e8ed50acf8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -591,6 +591,7 @@ class ALSCleanerSuite extends SparkFunSuite {
val spark = SparkSession.builder
.master("local[2]")
.appName("ALSCleanerSuite")
+ .sparkContext(sc)
.getOrCreate()
import spark.implicits._
val als = new ALS()
@@ -606,7 +607,7 @@ class ALSCleanerSuite extends SparkFunSuite {
val pattern = "shuffle_(\\d+)_.+\\.data".r
val rddIds = resultingFiles.flatMap { f =>
pattern.findAllIn(f.getName()).matchData.map { _.group(1) } }
- assert(rddIds.toSet.size === 4)
+ assert(rddIds.size === 4)
} finally {
sc.stop()
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 8cbd652bac..d2fa8d0d63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -42,9 +42,10 @@ private[ml] object TreeTests extends SparkFunSuite {
data: RDD[LabeledPoint],
categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = {
- val spark = SparkSession.builder
+ val spark = SparkSession.builder()
.master("local[2]")
.appName("TreeTests")
+ .sparkContext(data.sparkContext)
.getOrCreate()
import spark.implicits._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 20e22baa35..dc4b72a6fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -694,7 +694,7 @@ object SparkSession {
private[this] var userSuppliedContext: Option[SparkContext] = None
- private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
+ private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
userSuppliedContext = Option(sparkContext)
this
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index db32b6b6be..97adffa8ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -1,25 +1,25 @@
/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package org.apache.spark.sql.execution.joins
import scala.reflect.ClassTag
-import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.SparkPlan
@@ -44,11 +44,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
*/
override def beforeAll(): Unit = {
super.beforeAll()
- val conf = new SparkConf()
- .setMaster("local-cluster[2,1,1024]")
- .setAppName("testing")
- val sc = new SparkContext(conf)
- spark = SparkSession.builder.getOrCreate()
+ spark = SparkSession.builder()
+ .master("local-cluster[2,1,1024]")
+ .appName("testing")
+ .getOrCreate()
}
override def afterAll(): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index a4bbe96cf8..d56bede0cc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource}
import org.apache.spark.sql.expressions.Window
@@ -282,15 +282,12 @@ object SetWarehouseLocationTest extends Logging {
val hiveWarehouseLocation = Utils.createTempDir()
hiveWarehouseLocation.delete()
- val conf = new SparkConf()
- conf.set("spark.ui.enabled", "false")
// We will use the value of spark.sql.warehouse.dir override the
// value of hive.metastore.warehouse.dir.
- conf.set("spark.sql.warehouse.dir", warehouseLocation.toString)
- conf.set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)
-
- val sparkSession = SparkSession.builder
- .config(conf)
+ val sparkSession = SparkSession.builder()
+ .config("spark.ui.enabled", "false")
+ .config("spark.sql.warehouse.dir", warehouseLocation.toString)
+ .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)
.enableHiveSupport()
.getOrCreate()