From d252b2d544a75f6c5523be3492494955050acf50 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 16 Dec 2015 11:07:54 -0800
Subject: [SPARK-12309][ML] Use sqlContext from MLlibTestSparkContext for
 spark.ml test suites

Use ```sqlContext``` from ```MLlibTestSparkContext``` rather than creating
a new one for spark.ml test suites. I have checked thoroughly and found
there are four test cases that need to be updated.

cc mengxr jkbradley

Author: Yanbo Liang

Closes #10279 from yanboliang/spark-12309.
---
 .../test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala   | 4 +---
 .../src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 3 +--
 .../test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala   | 4 +---
 mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala        | 2 +-
 .../test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala  | 3 +--
 5 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 09183fe65b..035bfc07b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -21,13 +21,11 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
 
 class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("MinMaxScaler fit basic case") {
-    val sqlContext = new SQLContext(sc)
-
     val data = Array(
       Vectors.dense(1, 0, Long.MinValue),
       Vectors.dense(2, 0, 0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index de3d438ce8..4688339019 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -61,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       Vectors.sparse(3, Seq())
     )
 
-    val sqlContext = new SQLContext(sc)
     dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
     normalizer = new Normalizer()
       .setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 74706a23e0..8acc3369c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -54,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
   }
 
   test("Test vector slicer") {
-    val sqlContext = new SQLContext(sc)
-
     val data = Array(
       Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
       Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 460849c79f..4e2d0e93bd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val sqlContext = new SQLContext(data.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(data.sparkContext)
     import sqlContext.implicits._
     val df = data.toDF()
     val numFeatures = data.first().features.size
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index dd6366050c..d281084f91 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite
@@ -39,7 +39,6 @@ class CrossValidatorSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sqlContext = new SQLContext(sc)
     dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
   }
-- 
cgit v1.2.3
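
For context, the shared-fixture pattern these suites now rely on looks roughly like the sketch below. This is a minimal illustration, not the actual `MLlibTestSparkContext` source: the trait name `SharedSQLContextFixture`, the suite name `ExampleSuite`, and the app name are placeholders, and the real trait performs more setup and teardown.

```scala
import org.scalatest.{BeforeAndAfterAll, FunSuite, Suite}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Placeholder trait illustrating the shape of MLlibTestSparkContext:
// one SparkContext and one SQLContext are created per suite, so tests
// reference the shared `sqlContext` instead of constructing their own.
trait SharedSQLContextFixture extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("MLSuiteFixture"))
    // Reuse any context already registered for this SparkContext.
    sqlContext = SQLContext.getOrCreate(sc)
  }

  override def afterAll(): Unit = {
    try {
      if (sc != null) {
        sc.stop()
      }
      sc = null
      sqlContext = null
    } finally {
      super.afterAll()
    }
  }
}

// A suite then uses the shared sqlContext, as the patched tests do:
class ExampleSuite extends FunSuite with SharedSQLContextFixture {
  test("uses the shared sqlContext") {
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b")))
      .toDF("id", "value")
    assert(df.count() === 2)
  }
}
```

Keeping a single suite-level `SQLContext`, and using `SQLContext.getOrCreate` in helpers such as `TreeTests`, avoids instantiating multiple contexts against the same `SparkContext`, which is exactly what the `new SQLContext(sc)` calls removed by this patch were doing.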