author     freeman <the.freeman.lab@gmail.com>    2015-04-02 21:37:44 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-04-02 21:38:19 -0700
commit     6e1c1ec67bc4d7e5700f523ec08db6bb25bd2302 (patch)
tree       9d5679a8d4bcb97da6fe030de7961b76fc1422b3
parent     8a0aa81ca37d337423db60edb09cf264cc2c6498 (diff)
[SPARK-6345][STREAMING][MLLIB] Fix for training with prediction
This patch fixes a reported bug that caused model updates not to propagate properly to model predictions during streaming regression. These minor changes in the model declaration fix the problem, and I expanded the tests to cover the scenario in which the bug was arising. The two new tests failed prior to the patch and now pass.

cc mengxr

Author: freeman <the.freeman.lab@gmail.com>

Closes #5037 from freeman-lab/train-predict-fix and squashes the following commits:

3af953e [freeman] Expand test coverage to include combined training and prediction
8f84fc8 [freeman] Move model declaration
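For context on why the closure form in predictOn/predictOnValues matters, here is a minimal, self-contained sketch (no Spark required). Model, Holder, and ModelCaptureSketch are hypothetical stand-ins, not MLlib classes. In Scala, a method value such as model.get.predict is eta-expanded with its prefix model.get evaluated immediately, so the function passed to map stays pinned to whatever model existed when the stream was set up; an explicit closure re-reads the model var each time it runs, so it sees the updates made by trainOn.

// Minimal, self-contained sketch; Model and Holder are hypothetical
// stand-ins used only to illustrate the closure-capture semantics.
object ModelCaptureSketch {

  class Model(offset: Double) {
    def predict(x: Double): Double = x + offset
  }

  class Holder {
    var model: Option[Model] = Some(new Model(0.0))

    // Eta-expansion: the prefix `model.get` is evaluated here, once,
    // so the returned function is pinned to the model present at call time.
    def pinnedPredict: Double => Double = model.get.predict

    // Explicit closure: `model` is re-read on every invocation,
    // so later updates to the var are visible to predictions.
    def livePredict: Double => Double = x => model.get.predict(x)
  }

  def main(args: Array[String]): Unit = {
    val h = new Holder
    val pinned = h.pinnedPredict      // captures the initial model
    val live = h.livePredict
    h.model = Some(new Model(100.0))  // simulate an update from trainOn
    println(pinned(1.0))  // 1.0   -> still uses the old model
    println(live(1.0))    // 101.0 -> sees the updated model
  }
}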
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala                  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala          2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala     27
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala           28
5 files changed, 62 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
index b89f38cf5a..7d33df3221 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
@@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
protected val algorithm = new LogisticRegressionWithSGD(
stepSize, numIterations, regParam, miniBatchFraction)
+ protected var model: Option[LogisticRegressionModel] = None
+
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index ce95c063db..cea8f3f473 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[
A <: GeneralizedLinearAlgorithm[M]] extends Logging {
/** The model to be updated and used for prediction. */
- protected var model: Option[M] = None
+ protected var model: Option[M]
/** The algorithm to use for updating. */
protected val algorithm: A
@@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction.")
}
- data.map(model.get.predict)
+ data.map{x => model.get.predict(x)}
}
/** Java-friendly version of `predictOn`. */
@@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction")
}
- data.mapValues(model.get.predict)
+ data.mapValues{x => model.get.predict(x)}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index e5e6301127..a49153bf73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] (
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
+ protected var model: Option[LinearRegressionModel] = None
+
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 8b3e6e5ce9..d50c43d439 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
assert(errors.forall(x => x <= 0.4))
}
+ // Test training combined with prediction
+ test("training and prediction") {
+ // create model initialized with zero weights
+ val model = new StreamingLogisticRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(-0.1))
+ .setStepSize(0.01)
+ .setNumIterations(10)
+
+ // generate sequence of simulated data for testing
+ val numBatches = 10
+ val nPoints = 100
+ val testInput = (0 until numBatches).map { i =>
+ LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1))
+ }
+
+ // train and predict
+ val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ })
+
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+ // assert that prediction error improves, ensuring that the updated model is being used
+ val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
+ assert(error.head > 0.8 & error.last < 0.2)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 70b43ddb7d..24fd8df691 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
assert(errors.forall(x => x <= 0.1))
}
+
+ // Test training combined with prediction
+ test("training and prediction") {
+ // create model initialized with zero weights
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(0.0, 0.0))
+ .setStepSize(0.2)
+ .setNumIterations(25)
+
+ // generate sequence of simulated data for testing
+ val numBatches = 10
+ val nPoints = 100
+ val testInput = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+ }
+
+ // train and predict
+ val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ })
+
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+ // assert that prediction error improves, ensuring that the updated model is being used
+ val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
+ assert((error.head - error.last) > 2)
+ }
}