author     Xiangrui Meng <meng@databricks.com>    2015-05-20 20:30:39 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-20 20:30:39 -0700
commit     ddec173cba63df723cd94508121d8c06d8c153c6 (patch)
tree       6732b44d12440832ebf30bcddf8936781b9b728d /mllib
parent     42c592adb381ff20832cce55e0849ed68dd7eee4 (diff)
[SPARK-7774] [MLLIB] add sqlContext to MLlibTestSparkContext
to simplify test suites that require a SQLContext.

Author: Xiangrui Meng <meng@databricks.com>

Closes #6303 from mengxr/SPARK-7774 and squashes the following commits:

0622b5a [Xiangrui Meng] update some other test suites
e1f9b8d [Xiangrui Meng] add sqlContext to MLlibTestSparkContext
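For context, here is a minimal sketch of what a test suite looks like after this change: it relies on the shared sqlContext initialized by MLlibTestSparkContext instead of constructing its own. The suite name and data are hypothetical, for illustration only.

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame

// Hypothetical example suite; not part of this patch.
class ExampleSuite extends FunSuite with MLlibTestSparkContext {

  @transient var dataset: DataFrame = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // sqlContext is already initialized by MLlibTestSparkContext.beforeAll(),
    // so the suite can use it directly.
    dataset = sqlContext.createDataFrame(Seq((0, "a"), (1, "b"))).toDF("id", "label")
  }

  test("shared sqlContext provides a usable DataFrame") {
    assert(dataset.count() === 2)
  }
}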
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala          |  7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala                 |  6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala                |  9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala                       |  9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala             |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala       | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala             |  7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala                 |  9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala           |  9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala             |  6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala                |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala       |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala          |  8
14 files changed, 20 insertions(+), 79 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 97f9749cb4..9f77d5f3ef 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -23,18 +23,16 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
- @transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
@transient var binaryDataset: DataFrame = _
private val eps: Double = 1e-5
override def beforeAll(): Unit = {
super.beforeAll()
- sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 990cfb08af..770b56890f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,24 +21,23 @@ import org.scalatest.FunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.util.MetadataUtils
-import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
- @transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
override def beforeAll(): Unit = {
super.beforeAll()
- sqlContext = new SQLContext(sc)
+
val nPoints = 1000
// The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index caf1b75959..8f6c6b39dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -20,18 +20,14 @@ package org.apache.spark.ml.feature
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
+import org.apache.spark.sql.{DataFrame, Row}
class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
@transient var data: Array[Double] = _
- @transient var sqlContext: SQLContext = _
override def beforeAll(): Unit = {
super.beforeAll()
- sqlContext = new SQLContext(sc)
data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 20d2f3ac66..0391bd8427 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -25,17 +25,10 @@ import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
- @transient private var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
-
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index eaee3443c1..f85e854716 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -22,17 +22,10 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
class IDFSuite extends FunSuite with MLlibTestSparkContext {
- @transient var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
-
def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
case data: DenseVector =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 92ec407b98..056b9eda86 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -21,16 +21,10 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
- private var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index c1d64fba0a..aa230ca073 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -18,22 +18,15 @@
package org.apache.spark.ml.feature
import org.scalatest.FunSuite
+import org.scalatest.exceptions.TestFailedException
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext}
-import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.sql.Row
class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
- @transient var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
-
test("Polynomial expansion with default parameter") {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index b6939e5870..89c2fe4557 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,15 +21,8 @@ import org.scalatest.FunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SQLContext
class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
- private var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index d186ead8f5..a46d08d651 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -22,7 +22,7 @@ import scala.beans.BeanInfo
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
@@ -30,13 +30,6 @@ case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
- @transient var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
-
test("RegexTokenizer") {
val tokenizer = new RegexTokenizer()
.setInputCol("rawText")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 0db27607bc..d0cd62c5e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -22,17 +22,10 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
- @transient var sqlContext: SQLContext = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- sqlContext = new SQLContext(sc)
- }
-
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 38dc83b124..b11b029c63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -26,15 +26,12 @@ import org.apache.spark.ml.attribute._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
-
+import org.apache.spark.sql.DataFrame
class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
import VectorIndexerSuite.FeatureData
- @transient var sqlContext: SQLContext = _
-
// identical, of length 3
@transient var densePoints1: DataFrame = _
@transient var sparsePoints1: DataFrame = _
@@ -86,7 +83,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
checkPair(densePoints1Seq, sparsePoints1Seq)
checkPair(densePoints2Seq, sparsePoints2Seq)
- sqlContext = new SQLContext(sc)
densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 6cc6ec94eb..9a35555e52 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -38,14 +38,12 @@ import org.apache.spark.util.Utils
class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
- private var sqlContext: SQLContext = _
private var tempDir: File = _
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Utils.createTempDir()
sc.setCheckpointDir(tempDir.getAbsolutePath)
- sqlContext = new SQLContext(sc)
}
override def afterAll(): Unit = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 80323ef520..50a78631fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,11 +22,10 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row}
class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
- @transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
/**
@@ -41,7 +40,6 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
*/
override def beforeAll(): Unit = {
super.beforeAll()
- sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index b658889476..5d1796ef65 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,13 +17,14 @@
package org.apache.spark.mllib.util
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SQLContext
trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
@transient var sc: SparkContext = _
+ @transient var sqlContext: SQLContext = _
override def beforeAll() {
super.beforeAll()
@@ -31,12 +32,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
.setMaster("local[2]")
.setAppName("MLlibUnitTest")
sc = new SparkContext(conf)
+ sqlContext = new SQLContext(sc)
}
override def afterAll() {
+ sqlContext = null
if (sc != null) {
sc.stop()
}
+ sc = null
super.afterAll()
}
}
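For readability, here is the trait as it reads after this patch, reconstructed from the hunks above:

package org.apache.spark.mllib.util

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll() {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
    // Create the shared SQLContext once per suite, right after the SparkContext.
    sqlContext = new SQLContext(sc)
  }

  override def afterAll() {
    // Drop the references so the contexts can be garbage-collected between suites.
    sqlContext = null
    if (sc != null) {
      sc.stop()
    }
    sc = null
    super.afterAll()
  }
}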