aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala14
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala17
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala18
3 files changed, 43 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index aee51bf22d..141052ba81 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -83,13 +83,15 @@ abstract class StreamingLinearAlgorithm[
throw new IllegalArgumentException("Model must be initialized before starting training.")
}
data.foreachRDD { (rdd, time) =>
- model = Some(algorithm.run(rdd, model.get.weights))
- logInfo("Model updated at time %s".format(time.toString))
- val display = model.get.weights.size match {
- case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
- case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ if (!rdd.isEmpty) {
+ model = Some(algorithm.run(rdd, model.get.weights))
+ logInfo(s"Model updated at time ${time.toString}")
+ val display = model.get.weights.size match {
+ case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+ case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ }
+ logInfo(s"Current model: weights, ${display}")
}
- logInfo("Current model: weights, %s".format (display))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index e98b61e13e..fd653296c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
assert(error.head > 0.8 & error.last < 0.2)
}
+
+ // Test empty RDDs in a stream
+ test("handling empty RDDs in a stream") {
+ val model = new StreamingLogisticRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(-0.1))
+ .setStepSize(0.01)
+ .setNumIterations(10)
+ val numBatches = 10
+ val emptyInput = Seq.empty[Seq[LabeledPoint]]
+ val ssc = setupStreams(emptyInput,
+ (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ }
+ )
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 9a379406d5..f5e2d31056 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
assert((error.head - error.last) > 2)
}
+
+ // Test empty RDDs in a stream
+ test("handling empty RDDs in a stream") {
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(0.0, 0.0))
+ .setStepSize(0.2)
+ .setNumIterations(25)
+ val numBatches = 10
+ val nPoints = 100
+ val emptyInput = Seq.empty[Seq[LabeledPoint]]
+ val ssc = setupStreams(emptyInput,
+ (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ }
+ )
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+ }
}