aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-05-31 17:40:44 -0700
committerAndrew Or <andrew@databricks.com>2016-05-31 17:40:44 -0700
commit85d6b0db9f5bd425c36482ffcb1c3b9fd0fcdb31 (patch)
tree2e09a7e6c626ec965d86b31fd3b64207be766349 /mllib/src/test
parent93e97147eb499dde1e54e07ba113eebcbe25508a (diff)
downloadspark-85d6b0db9f5bd425c36482ffcb1c3b9fd0fcdb31.tar.gz
spark-85d6b0db9f5bd425c36482ffcb1c3b9fd0fcdb31.tar.bz2
spark-85d6b0db9f5bd425c36482ffcb1c3b9fd0fcdb31.zip
[SPARK-15618][SQL][MLLIB] Use SparkSession.builder.sparkContext if applicable.
## What changes were proposed in this pull request? This PR changes function `SparkSession.builder.sparkContext(..)` from **private[sql]** into **private[spark]**, and uses it if applicable like the followings. ``` - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() ``` ## How was this patch tested? Pass the existing Jenkins tests. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13365 from dongjoon-hyun/SPARK-15618.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala3
4 files changed, 9 insertions, 11 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 40d5b4881f..3558290b23 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -23,18 +23,14 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.Row
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
test("Test Chi-Square selector") {
- val spark = SparkSession.builder
- .master("local[2]")
- .appName("ChiSqSelectorSuite")
- .getOrCreate()
+ val spark = this.spark
import spark.implicits._
-
val data = Seq(
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 621c13a8e5..b73dbd6232 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -27,7 +27,7 @@ class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("Test observed number of buckets and their sizes match expected values") {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = this.spark
import spark.implicits._
val datasetSize = 100000
@@ -53,7 +53,7 @@ class QuantileDiscretizerSuite
}
test("Test transform method on unseen data") {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = this.spark
import spark.implicits._
val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
@@ -82,7 +82,7 @@ class QuantileDiscretizerSuite
}
test("Verify resulting model has parent") {
- val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
+ val spark = this.spark
import spark.implicits._
val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 59b5edc401..e8ed50acf8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -591,6 +591,7 @@ class ALSCleanerSuite extends SparkFunSuite {
val spark = SparkSession.builder
.master("local[2]")
.appName("ALSCleanerSuite")
+ .sparkContext(sc)
.getOrCreate()
import spark.implicits._
val als = new ALS()
@@ -606,7 +607,7 @@ class ALSCleanerSuite extends SparkFunSuite {
val pattern = "shuffle_(\\d+)_.+\\.data".r
val rddIds = resultingFiles.flatMap { f =>
pattern.findAllIn(f.getName()).matchData.map { _.group(1) } }
- assert(rddIds.toSet.size === 4)
+ assert(rddIds.size === 4)
} finally {
sc.stop()
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 8cbd652bac..d2fa8d0d63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -42,9 +42,10 @@ private[ml] object TreeTests extends SparkFunSuite {
data: RDD[LabeledPoint],
categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = {
- val spark = SparkSession.builder
+ val spark = SparkSession.builder()
.master("local[2]")
.appName("TreeTests")
+ .sparkContext(data.sparkContext)
.getOrCreate()
import spark.implicits._