 mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala |  2 ++
 mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala |  6 +++---
 mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala |  2 ++
 mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala | 27 +++++++++++++++++++++++++++
 mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala | 28 ++++++++++++++++++++++++++++
 5 files changed, 62 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
index b89f38cf5a..7d33df3221 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
@@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
protected val algorithm = new LogisticRegressionWithSGD(
stepSize, numIterations, regParam, miniBatchFraction)
+ protected var model: Option[LogisticRegressionModel] = None
+
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index ce95c063db..cea8f3f473 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[
A <: GeneralizedLinearAlgorithm[M]] extends Logging {
/** The model to be updated and used for prediction. */
- protected var model: Option[M] = None
+ protected var model: Option[M]
/** The algorithm to use for updating. */
protected val algorithm: A
@@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction.")
}
- data.map(model.get.predict)
+ data.map{x => model.get.predict(x)}
}
/** Java-friendly version of `predictOn`. */
@@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction")
}
- data.mapValues(model.get.predict)
+ data.mapValues{x => model.get.predict(x)}
}
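
The two closure rewrites above are the substance of this file's change. model.get.predict is eta-expanded at the moment predictOn or predictOnValues is called, so model.get is evaluated once and every later batch is scored with whatever model existed at setup time; {x => model.get.predict(x)} re-reads model on each invocation, so batches are scored with the model most recently produced by trainOn. A minimal, self-contained Scala sketch of that binding difference (the Model class and the values are illustrative only, not part of the patch):

    object ClosureBindingSketch {
      class Model(offset: Double) {
        def predict(x: Double): Double = x + offset
      }

      // Mutable holder, standing in for the model field of StreamingLinearAlgorithm.
      var model: Option[Model] = Some(new Model(0.0))

      def main(args: Array[String]): Unit = {
        // Eta-expansion evaluates model.get here, once, so this function
        // stays bound to the initial model even after model is reassigned.
        val boundAtSetup: Double => Double = model.get.predict

        // The closure re-evaluates model.get on every invocation.
        val boundPerCall: Double => Double = x => model.get.predict(x)

        model = Some(new Model(100.0)) // stand-in for trainOn installing a new model

        println(boundAtSetup(1.0)) // 1.0   (initial model)
        println(boundPerCall(1.0)) // 101.0 (updated model)
      }
    }

In the streaming job the same distinction plays out batch by batch: the lambda is serialized along with the algorithm instance when each batch job runs, so it sees whatever model is current at that point, while the eta-expanded method value had already pinned the initial model.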
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index e5e6301127..a49153bf73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] (
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
+ protected var model: Option[LinearRegressionModel] = None
+
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 8b3e6e5ce9..d50c43d439 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
assert(errors.forall(x => x <= 0.4))
}
+ // Test training combined with prediction
+ test("training and prediction") {
+ // create model initialized with a small nonzero weight
+ val model = new StreamingLogisticRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(-0.1))
+ .setStepSize(0.01)
+ .setNumIterations(10)
+
+ // generate sequence of simulated data for testing
+ val numBatches = 10
+ val nPoints = 100
+ val testInput = (0 until numBatches).map { i =>
+ LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1))
+ }
+
+ // train and predict
+ val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ })
+
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+ // assert that prediction error improves, ensuring that the updated model is being used
+ val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
+ assert(error.head > 0.8 && error.last < 0.2)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 70b43ddb7d..24fd8df691 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
assert(errors.forall(x => x <= 0.1))
}
+
+ // Test training combined with prediction
+ test("training and prediction") {
+ // create model initialized with zero weights
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(0.0, 0.0))
+ .setStepSize(0.2)
+ .setNumIterations(25)
+
+ // generate sequence of simulated data for testing
+ val numBatches = 10
+ val nPoints = 100
+ val testInput = (0 until numBatches).map { i =>
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+ }
+
+ // train and predict
+ val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ })
+
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+ // assert that prediction error improves, ensuring that the updated model is being used
+ val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
+ assert((error.head - error.last) > 2)
+ }
}
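
Both new tests exercise training and prediction through the TestSuiteBase harness. For orientation, here is a hedged sketch of how the same API is typically wired in a standalone application, assuming labeled points arrive as text files in a format LabeledPoint.parse accepts; the directory paths, batch interval, master, and application name are placeholders, not taken from this patch:

    import org.apache.spark.SparkConf
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object StreamingRegressionSketch {
      def main(args: Array[String]): Unit = {
        // local[2] so the receiver and the batch processing each get a thread
        val conf = new SparkConf().setMaster("local[2]").setAppName("StreamingRegressionSketch")
        val ssc = new StreamingContext(conf, Seconds(1))

        // One labeled point per line, e.g. "(1.5,[0.1,0.2])"; files dropped into
        // these placeholder directories are picked up once per batch interval.
        val trainingData = ssc.textFileStream("/tmp/training").map(LabeledPoint.parse)
        val testData = ssc.textFileStream("/tmp/test").map(LabeledPoint.parse)

        val model = new StreamingLinearRegressionWithSGD()
          .setInitialWeights(Vectors.zeros(2))
          .setStepSize(0.2)
          .setNumIterations(25)

        model.trainOn(trainingData)
        // With the closure change in StreamingLinearAlgorithm, each batch is
        // scored by the most recently trained model.
        model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

        ssc.start()
        ssc.awaitTermination()
      }
    }

The (label, prediction) pairs printed here have the same shape that the assertions above average over when checking that the per-batch error shrinks.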