author     Yanbo Liang <ybliang8@gmail.com>           2015-12-16 11:07:54 -0800
committer  Joseph K. Bradley <joseph@databricks.com>  2015-12-16 11:07:54 -0800
commit     d252b2d544a75f6c5523be3492494955050acf50 (patch)
tree       49f0e5349765a45116d582570df160bfa8a19bb1 /mllib/src/test/scala/org/apache
parent     860dc7f2f8dd01f2562ba83b7af27ba29d91cb62 (diff)
[SPARK-12309][ML] Use sqlContext from MLlibTestSparkContext for spark.ml test suites
Use `sqlContext` from `MLlibTestSparkContext` rather than creating a new one for the spark.ml test suites. I have checked thoroughly and found four test cases that need to be updated. cc mengxr jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #10279 from yanboliang/spark-12309.
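For context, the pattern the commit standardizes on: `MLlibTestSparkContext` is a shared-fixture trait that initializes one `SparkContext` (`sc`) and one `SQLContext` (`sqlContext`) in `beforeAll` and tears them down in `afterAll`, so a suite that mixes it in never needs a `new SQLContext(sc)` of its own. A minimal sketch of such a trait follows; `SharedContextSketch` and its internals are illustrative assumptions, not the real trait's code, which lives in Spark's mllib test sources.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.scalatest.{BeforeAndAfterAll, Suite}

// Hypothetical trait illustrating the shared-fixture idea behind
// MLlibTestSparkContext: one context pair per suite, created exactly once.
trait SharedContextSketch extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("MLlibUnitTest"))
    // Reuse the JVM-wide singleton if one already exists.
    sqlContext = SQLContext.getOrCreate(sc)
  }

  override def afterAll(): Unit = {
    try {
      if (sc != null) sc.stop()
      sc = null
      sqlContext = null
    } finally {
      super.afterAll()
    }
  }
}
```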
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala   4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala     3
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala   4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala              2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala  3
5 files changed, 5 insertions(+), 11 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 09183fe65b..035bfc07b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -21,13 +21,11 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
 
 class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("MinMaxScaler fit basic case") {
-    val sqlContext = new SQLContext(sc)
-
     val data = Array(
       Vectors.dense(1, 0, Long.MinValue),
       Vectors.dense(2, 0, 0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index de3d438ce8..4688339019 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -61,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       Vectors.sparse(3, Seq())
     )
 
-    val sqlContext = new SQLContext(sc)
     dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
     normalizer = new Normalizer()
       .setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 74706a23e0..8acc3369c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -54,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
   }
 
   test("Test vector slicer") {
-    val sqlContext = new SQLContext(sc)
-
     val data = Array(
       Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
       Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 460849c79f..4e2d0e93bd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val sqlContext = new SQLContext(data.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(data.sparkContext)
     import sqlContext.implicits._
     val df = data.toDF()
     val numFeatures = data.first().features.size
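Unlike the suites above, `TreeTests` is a helper object rather than a test suite, so it cannot mix in `MLlibTestSparkContext`; instead the fix switches to `SQLContext.getOrCreate`, which returns the JVM-wide singleton `SQLContext` (creating it only if none exists) rather than instantiating a fresh one on every call. A hedged sketch of the same pattern in isolation; the `toLabeledDF` helper and its column names are made up for illustration:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

// Illustrative helper: converts an RDD without owning a context of its own.
// Repeated calls reuse the singleton SQLContext instead of leaking new ones.
def toLabeledDF(data: RDD[(Double, Double)]): DataFrame = {
  val sqlContext = SQLContext.getOrCreate(data.sparkContext)
  import sqlContext.implicits._
  data.toDF("label", "feature")  // column names are assumptions, not the suite's
}
```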
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index dd6366050c..d281084f91 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite
@@ -39,7 +39,6 @@ class CrossValidatorSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sqlContext = new SQLContext(sc)
     dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
   }
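One detail worth noting in the hunk above: the remaining `sqlContext` is the one inherited from `MLlibTestSparkContext`, and it only exists after `super.beforeAll()` has run, so the statement order inside `beforeAll` matters. A sketch of that ordering constraint, condensed from the diff and assuming the trait initializes `sc` and `sqlContext` in its own `beforeAll`:

```scala
override def beforeAll(): Unit = {
  super.beforeAll()  // MLlibTestSparkContext creates sc and sqlContext here
  // Safe: sqlContext was initialized by the line above. Reversing the two
  // statements would dereference a null sqlContext.
  dataset = sqlContext.createDataFrame(
    sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
}
```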