author    Andrew Tulloch <andrew@tullo.ch>    2014-01-19 17:51:00 +0000
committer Andrew Tulloch <andrew@tullo.ch>    2014-01-19 17:51:00 +0000
commit    720836a76169be4f7ad010c6f6dbb54666b6aba1 (patch)
tree      bc2fc2dc6bc28b9f0223c8a4d5dae5bd8d71d884 /mllib
parent    fe8a3546f40394466a41fc750cb60f6fc73d8bbb (diff)
download  spark-720836a76169be4f7ad010c6f6dbb54666b6aba1.tar.gz
          spark-720836a76169be4f7ad010c6f6dbb54666b6aba1.tar.bz2
          spark-720836a76169be4f7ad010c6f6dbb54666b6aba1.zip
LocalSparkContext for MLlib
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala | 15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala | 15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala | 15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala | 23
10 files changed, 42 insertions, 109 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 02ede71137..155788f85e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.mllib.util.LocalSparkContext
import scala.util.Random
import scala.collection.JavaConversions._
@@ -66,19 +67,7 @@ object LogisticRegressionSuite {
}
-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
-
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index b615f76e66..375196d97f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.mllib.util.LocalSparkContext
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
@@ -59,17 +60,7 @@ object NaiveBayesSuite {
}
}
-class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class NaiveBayesSuite extends FunSuite with LocalSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOfPredictions = predictions.zip(input).count {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 3357b86f9b..bc7abb568a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -25,8 +25,9 @@ import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
-import org.apache.spark.{SparkException, SparkContext}
+import org.apache.spark.SparkException
import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
object SVMSuite {
@@ -58,17 +59,7 @@ object SVMSuite {
}
-class SVMSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class SVMSuite extends FunSuite with LocalSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 73657cac89..4ef1d1f64f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext
-
-class KMeansSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class KMeansSuite extends FunSuite with LocalSparkContext {
val EPSILON = 1e-4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a6028a1e98..a453de6767 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
object GradientDescentSuite {
@@ -62,17 +63,7 @@ object GradientDescentSuite {
}
}
-class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
test("Assert the loss is decreasing.") {
val nPoints = 10000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 4e8dbde658..28a27b1d09 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext
import org.jblas._
@@ -73,17 +73,7 @@ object ALSSuite {
}
-class ALSSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class ALSSuite extends FunSuite with LocalSparkContext {
test("rank-1 matrices") {
testALS(50, 100, 1, 15, 0.7, 0.3)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index b2c8df97a8..64e4cbb860 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-
-class LassoSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class LassoSuite extends FunSuite with LocalSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 406afbaa3e..648d89bc8c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -21,19 +21,9 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
-
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class LinearRegressionSuite extends FunSuite with LocalSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 1d6a10b66e..7ec53eb772 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -23,19 +23,10 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
- @transient private var sc: SparkContext = _
- override def beforeAll() {
- sc = new SparkContext("local", "test")
- }
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
- }
+class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
predictions.zip(input).map { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
new file mode 100644
index 0000000000..7d840043e5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -0,0 +1,23 @@
+package org.apache.spark.mllib.util
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkContext
+
+trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
+ @transient var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (sc != null) {
+ sc.stop()
+ }
+ System.clearProperty("spark.driver.port")
+ super.afterAll()
+ }
+}
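
For context, a minimal sketch of how a suite consumes the new trait. The suite and test names below are hypothetical, not part of this commit; the point is that the mixin supplies `sc` via its beforeAll/afterAll hooks, so individual suites no longer carry their own SparkContext setup and teardown.

package org.apache.spark.mllib.util

import org.scalatest.FunSuite

// Hypothetical example suite, not part of this commit.
class LocalSparkContextExampleSuite extends FunSuite with LocalSparkContext {
  test("sc is provided by the mixin") {
    // `sc` is created in LocalSparkContext.beforeAll and stopped in afterAll.
    val rdd = sc.parallelize(1 to 100)
    assert(rdd.count() === 100)
  }
}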