about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorPaavo <pparkkin@gmail.com>2015-06-10 23:17:42 +0100
committerSean Owen <sowen@cloudera.com>2015-06-10 23:17:42 +0100
commitb928f543845ddd39e914a0e8f0b0205fd86100c5 (patch)
tree6176fbb68e99ab03e5d4f057a07779c7e20d439a /mllib
parent96a7c888d806adfdb2c722025a1079ed7eaa2052 (diff)
downloadspark-b928f543845ddd39e914a0e8f0b0205fd86100c5.tar.gz
spark-b928f543845ddd39e914a0e8f0b0205fd86100c5.tar.bz2
spark-b928f543845ddd39e914a0e8f0b0205fd86100c5.zip
[SPARK-8200] [MLLIB] Check for empty RDDs in StreamingLinearAlgorithm
Test cases for both StreamingLinearRegression and StreamingLogisticRegression, and code fix. Edit: This contribution is my original work and I license the work to the project under the project's open source license. Author: Paavo <pparkkin@gmail.com> Closes #6713 from pparkkin/streamingmodel-empty-rdd and squashes the following commits: ff5cd78 [Paavo] Update strings to use interpolation. db234cf [Paavo] Use !rdd.isEmpty. 54ad89e [Paavo] Test case for empty stream. 393e36f [Paavo] Ignore empty RDDs. 0bfc365 [Paavo] Test case for empty stream.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala | 14
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala | 17
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala | 18
3 files changed, 43 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index aee51bf22d..141052ba81 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -83,13 +83,15 @@ abstract class StreamingLinearAlgorithm[
throw new IllegalArgumentException("Model must be initialized before starting training.")
}
data.foreachRDD { (rdd, time) =>
- model = Some(algorithm.run(rdd, model.get.weights))
- logInfo("Model updated at time %s".format(time.toString))
- val display = model.get.weights.size match {
- case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
- case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ if (!rdd.isEmpty) {
+ model = Some(algorithm.run(rdd, model.get.weights))
+ logInfo(s"Model updated at time ${time.toString}")
+ val display = model.get.weights.size match {
+ case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+ case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ }
+ logInfo(s"Current model: weights, ${display}")
}
- logInfo("Current model: weights, %s".format (display))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index e98b61e13e..fd653296c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
assert(error.head > 0.8 & error.last < 0.2)
}
+
+ // Test empty RDDs in a stream
+ test("handling empty RDDs in a stream") {
+ val model = new StreamingLogisticRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(-0.1))
+ .setStepSize(0.01)
+ .setNumIterations(10)
+ val numBatches = 10
+ val emptyInput = Seq.empty[Seq[LabeledPoint]]
+ val ssc = setupStreams(emptyInput,
+ (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ }
+ )
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 9a379406d5..f5e2d31056 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
assert((error.head - error.last) > 2)
}
+
+ // Test empty RDDs in a stream
+ test("handling empty RDDs in a stream") {
+ val model = new StreamingLinearRegressionWithSGD()
+ .setInitialWeights(Vectors.dense(0.0, 0.0))
+ .setStepSize(0.2)
+ .setNumIterations(25)
+ val numBatches = 10
+ val nPoints = 100
+ val emptyInput = Seq.empty[Seq[LabeledPoint]]
+ val ssc = setupStreams(emptyInput,
+ (inputDStream: DStream[LabeledPoint]) => {
+ model.trainOn(inputDStream)
+ model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+ }
+ )
+ val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+ }
}