author     freeman <the.freeman.lab@gmail.com>          2014-08-19 13:28:57 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>  2014-08-19 13:28:57 -0700
commit     31f0b071efd0b63eb9d6a6a131e5c4fa28237583 (patch)
tree       2ce7834721e9296510c91673d67b54528709b2eb /mllib
parent     cbfc26ba45f49559e64276c72e3054c6fe30ddd5 (diff)
[SPARK-3128][MLLIB] Use streaming test suite for StreamingLR
Refactored tests for streaming linear regression to use existing streaming test utilities. Summary of changes:

- Made ``mllib`` depend on tests from ``streaming``
- Rewrote accuracy and convergence tests to use ``setupStreams`` and ``runStreams``
- Added new test for the accuracy of predictions generated by ``predictOnValues``

These tests should run faster, be easier to extend/maintain, and provide a reference for new tests.

mengxr tdas

Author: freeman <the.freeman.lab@gmail.com>

Closes #2037 from freeman-lab/streamingLR-predict-tests and squashes the following commits:

e851ca7 [freeman] Fixed long lines
50eb0bf [freeman] Refactored tests to use streaming test tools
32c43c2 [freeman] Added test for prediction
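For readers unfamiliar with the harness these tests now use: each element of the input sequence becomes one simulated batch, ``setupStreams`` wires an operation over the resulting ``DStream`` into a ``StreamingContext`` driven by a manual clock, and ``runStreams`` advances the clock and collects one output ``Seq`` per batch. A minimal sketch of that pattern follows, matching the ``setupStreams``/``runStreams`` usage visible in the diff below; the suite name and the trivial label-extracting operation are illustrative assumptions, not part of this patch.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream
import org.scalatest.FunSuite

// Illustrative example of the TestSuiteBase pattern; not part of the patch.
class ExampleHarnessSuite extends FunSuite with TestSuiteBase {

  test("inputs flow through setupStreams/runStreams one batch at a time") {
    val numBatches = 3
    // one Seq per simulated batch; here each batch holds a single point
    val input = (0 until numBatches).map { i =>
      Seq(LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)))
    }
    // wire an operation over the input DStream into a StreamingContext
    val ssc = setupStreams(input, (ds: DStream[LabeledPoint]) => ds.map(_.label))
    // run numBatches batches and collect one output Seq per batch
    val output: Seq[Seq[Double]] = runStreams(ssc, numBatches, numBatches)
    assert(output.flatten == Seq(0.0, 1.0, 2.0))
  }
}
```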
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/pom.xml                                                                                 |   7
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala  | 121
2 files changed, 74 insertions(+), 54 deletions(-)
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fc1ecfbea7..c7a1e2ae75 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -91,6 +91,13 @@
<artifactId>junit-interface</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<profiles>
<profile>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 45e25eecf5..28489410f8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -17,20 +17,19 @@
package org.apache.spark.mllib.regression
-import java.io.File
-import java.nio.charset.Charset
-
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase

-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
+
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}
// Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
- test("streaming linear regression parameter accuracy") {
+ test("parameter accuracy") {
- val testDir = Files.createTempDir()
- val numBatches = 10
- val batchDuration = Milliseconds(1000)
- val ssc = new StreamingContext(sc, batchDuration)
- val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+ // create model
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0, 0.0))
.setStepSize(0.1)
- .setNumIterations(50)
+ .setNumIterations(25)
- model.trainOn(data)
-
- ssc.start()
-
- // write data to a file stream
- for (i <- 0 until numBatches) {
- val samples = LinearDataGenerator.generateLinearInput(
- 0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
- val file = new File(testDir, i.toString)
- Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
- Thread.sleep(batchDuration.milliseconds)
+ // generate sequence of simulated data
+ val numBatches = 10
+ val input = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
}
- ssc.stop(stopSparkContext=false)
-
- System.clearProperty("spark.driver.port")
- Utils.deleteRecursively(testDir)
+ // apply model training to input stream
+ val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
// check accuracy of final parameter estimates
assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}
// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
- test("streaming linear regression parameter convergence") {
+ test("parameter convergence") {
- val testDir = Files.createTempDir()
- val batchDuration = Milliseconds(2000)
- val ssc = new StreamingContext(sc, batchDuration)
- val numBatches = 5
- val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+ // create model
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0))
.setStepSize(0.1)
- .setNumIterations(50)
-
- model.trainOn(data)
-
- ssc.start()
+ .setNumIterations(25)
- // write data to a file stream
- val history = new ArrayBuffer[Double](numBatches)
- for (i <- 0 until numBatches) {
- val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
- val file = new File(testDir, i.toString)
- Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
- Thread.sleep(batchDuration.milliseconds)
- // wait an extra few seconds to make sure the update finishes before new data arrive
- Thread.sleep(4000)
- history.append(math.abs(model.latestModel().weights(0) - 10.0))
+ // generate sequence of simulated data
+ val numBatches = 10
+ val input = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
}
- ssc.stop(stopSparkContext=false)
+ // create buffer to store intermediate fits
+ val history = new ArrayBuffer[Double](numBatches)
- System.clearProperty("spark.driver.port")
- Utils.deleteRecursively(testDir)
+ // apply model training to input stream, storing the intermediate results
+ // (we add a count to ensure the result is a DStream)
+ val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+ // compute change in error
val deltas = history.drop(1).zip(history.dropRight(1))
// check error stability (it always either shrinks, or increases with small tol)
assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}
+ // Test predictions on a stream
+ test("predictions") {
+
+ // create model initialized with true weights
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(10.0, 10.0))
+ .setStepSize(0.1)
+ .setNumIterations(25)
+
+ // generate sequence of simulated data for testing
+ val numBatches = 10
+ val nPoints = 100
+ val testInput = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+ }
+
+ // apply model predictions to test stream
+ val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ })
+ // collect the output as (true, estimated) tuples
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+ // compute the mean absolute error and check that it's always less than 0.1
+ val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
+ assert(errors.forall(x => x <= 0.1))
+
+ }
+
}
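Outside the test harness, the same estimator is driven from live DStreams. The sketch below shows ``trainOn`` and ``predictOnValues`` in an application setting, assuming the file-stream ingestion the old tests relied on (``textFileStream`` plus ``LabeledPoint.parse``); the object name and directory paths are placeholders.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Hypothetical driver; paths and names are placeholders for illustration.
object StreamingLRApp {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("StreamingLRApp")
    val ssc = new StreamingContext(conf, Seconds(1))

    // LabeledPoint.parse reads the same text format the old file-stream tests wrote
    val training = ssc.textFileStream("/tmp/streamingLR/train").map(LabeledPoint.parse)
    val test = ssc.textFileStream("/tmp/streamingLR/test").map(LabeledPoint.parse)

    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.dense(0.0, 0.0))

    // update the model on each training batch
    model.trainOn(training)
    // emit (true label, predicted value) pairs for each test batch
    model.predictOnValues(test.map(lp => (lp.label, lp.features))).print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```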