aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorfreeman <the.freeman.lab@gmail.com>2015-04-02 21:37:44 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-02 21:38:19 -0700
commit6e1c1ec67bc4d7e5700f523ec08db6bb25bd2302 (patch)
tree9d5679a8d4bcb97da6fe030de7961b76fc1422b3 /mllib/src/test
parent8a0aa81ca37d337423db60edb09cf264cc2c6498 (diff)
downloadspark-6e1c1ec67bc4d7e5700f523ec08db6bb25bd2302.tar.gz
spark-6e1c1ec67bc4d7e5700f523ec08db6bb25bd2302.tar.bz2
spark-6e1c1ec67bc4d7e5700f523ec08db6bb25bd2302.zip
[SPARK-6345][STREAMING][MLLIB] Fix for training with prediction
This patch fixes a reported bug causing model updates to not properly propagate to model predictions during streaming regression. These minor changes in model declaration fix the problem, and I expanded the tests to include the scenario in which the bug was arising. The two new tests failed prior to the patch and now pass. cc mengxr Author: freeman <the.freeman.lab@gmail.com> Closes #5037 from freeman-lab/train-predict-fix and squashes the following commits: 3af953e [freeman] Expand test coverage to include combined training and prediction 8f84fc8 [freeman] Move model declaration
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala27
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala28
2 files changed, 55 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 8b3e6e5ce9..d50c43d439 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
assert(errors.forall(x => x <= 0.4))
}
+  // Test training combined with prediction
+  test("training and prediction") {
+    // create model initialized with a small non-zero weight (-0.1)
+    val model = new StreamingLogisticRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(-0.1))
+      .setStepSize(0.01)
+      .setNumIterations(10)
+
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1))
+    }
+
+    // train and predict on the same input stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // assert that prediction error improves, ensuring that the updated model is being used
+    val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
+    assert(error.head > 0.8 && error.last < 0.2)
+  }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 70b43ddb7d..24fd8df691 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
assert(errors.forall(x => x <= 0.1))
}
+
+  // Test training combined with prediction
+  test("training and prediction") {
+    // the model starts from an all-zero weight vector
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.2)
+      .setNumIterations(25)
+
+    // build one batch of simulated input per streaming interval
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = for (i <- 0 until numBatches) yield {
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+    }
+
+    // register both training and prediction on the same input stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      model.predictOnValues(inputDStream.map(pt => (pt.label, pt.features)))
+    })
+
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // per-batch mean absolute error must drop from first batch to last (updated model in use)
+    val error = output.map(_.map { case (y, yHat) => math.abs(y - yHat) }.sum / nPoints).toList
+    assert((error.head - error.last) > 2)
+  }
}