aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/pom.xml7
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala121
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala4
3 files changed, 77 insertions, 55 deletions
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fc1ecfbea7..c7a1e2ae75 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -91,6 +91,13 @@
<artifactId>junit-interface</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<profiles>
<profile>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 45e25eecf5..28489410f8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -17,20 +17,19 @@
package org.apache.spark.mllib.regression
-import java.io.File
-import java.nio.charset.Charset
-
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+ // use longer wait time to ensure job completion
+ override def maxWaitTimeMillis = 20000
// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}
// Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
- test("streaming linear regression parameter accuracy") {
+ test("parameter accuracy") {
- val testDir = Files.createTempDir()
- val numBatches = 10
- val batchDuration = Milliseconds(1000)
- val ssc = new StreamingContext(sc, batchDuration)
- val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+ // create model
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0, 0.0))
.setStepSize(0.1)
- .setNumIterations(50)
+ .setNumIterations(25)
- model.trainOn(data)
-
- ssc.start()
-
- // write data to a file stream
- for (i <- 0 until numBatches) {
- val samples = LinearDataGenerator.generateLinearInput(
- 0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
- val file = new File(testDir, i.toString)
- Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
- Thread.sleep(batchDuration.milliseconds)
+ // generate sequence of simulated data
+ val numBatches = 10
+ val input = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
}
- ssc.stop(stopSparkContext=false)
-
- System.clearProperty("spark.driver.port")
- Utils.deleteRecursively(testDir)
+ // apply model training to input stream
+ val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
// check accuracy of final parameter estimates
assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}
// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
- test("streaming linear regression parameter convergence") {
+ test("parameter convergence") {
- val testDir = Files.createTempDir()
- val batchDuration = Milliseconds(2000)
- val ssc = new StreamingContext(sc, batchDuration)
- val numBatches = 5
- val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+ // create model
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0))
.setStepSize(0.1)
- .setNumIterations(50)
-
- model.trainOn(data)
-
- ssc.start()
+ .setNumIterations(25)
- // write data to a file stream
- val history = new ArrayBuffer[Double](numBatches)
- for (i <- 0 until numBatches) {
- val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
- val file = new File(testDir, i.toString)
- Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
- Thread.sleep(batchDuration.milliseconds)
- // wait an extra few seconds to make sure the update finishes before new data arrive
- Thread.sleep(4000)
- history.append(math.abs(model.latestModel().weights(0) - 10.0))
+ // generate sequence of simulated data
+ val numBatches = 10
+ val input = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
}
- ssc.stop(stopSparkContext=false)
+ // create buffer to store intermediate fits
+ val history = new ArrayBuffer[Double](numBatches)
- System.clearProperty("spark.driver.port")
- Utils.deleteRecursively(testDir)
+ // apply model training to input stream, storing the intermediate results
+ // (we add a count to ensure the result is a DStream)
+ val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+ // compute change in error
val deltas = history.drop(1).zip(history.dropRight(1))
// check error stability (it always either shrinks, or increases with small tol)
assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}
+ // Test predictions on a stream
+ test("predictions") {
+
+ // create model initialized with true weights
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(10.0, 10.0))
+ .setStepSize(0.1)
+ .setNumIterations(25)
+
+ // generate sequence of simulated data for testing
+ val numBatches = 10
+ val nPoints = 100
+ val testInput = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+ }
+
+ // apply model predictions to test stream
+ val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ })
+ // collect the output as (true, estimated) tuples
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+ // compute the mean absolute error and check that it's always less than 0.1
+ val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
+ assert(errors.forall(x => x <= 0.1))
+
+ }
+
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index cc178fba12..f095da9cb5 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -242,7 +242,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
// Get the output buffer
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+ val outputStream = ssc.graph.getOutputStreams.
+ filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]).
+ head.asInstanceOf[TestOutputStreamWithPartitions[V]]
val output = outputStream.output
try {