From 659329f9ee51ca8ae6232e07c45b5d9144d49667 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 3 Feb 2015 00:14:43 -0800
Subject: [minor] update streaming linear algorithms

Author: Xiangrui Meng

Closes #4329 from mengxr/streaming-lr and squashes the following commits:

78731e1 [Xiangrui Meng] update streaming linear algorithms
---
 .../StreamingLogisticRegressionWithSGD.scala       |  3 +-
 .../regression/StreamingLinearAlgorithm.scala      | 41 ++++++++++++----------
 .../StreamingLinearRegressionWithSGD.scala         |  2 +-
 3 files changed, 24 insertions(+), 22 deletions(-)

(limited to 'mllib')

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
index eabd2162e2..6a3893d0e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
@@ -88,8 +88,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
 
   /** Set the initial weights. Default: [0.0, 0.0]. */
   def setInitialWeights(initialWeights: Vector): this.type = {
-    this.model = Option(algorithm.createModel(initialWeights, 0.0))
+    this.model = Some(algorithm.createModel(initialWeights, 0.0))
     this
   }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index 39a0dee931..44a8dbb994 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.streaming.dstream.DStream
 
 /**
@@ -58,7 +58,7 @@ abstract class StreamingLinearAlgorithm[
     A <: GeneralizedLinearAlgorithm[M]] extends Logging {
 
   /** The model to be updated and used for prediction. */
-  protected var model: Option[M] = null
+  protected var model: Option[M] = None
 
   /** The algorithm to use for updating. */
   protected val algorithm: A
@@ -77,18 +77,25 @@ abstract class StreamingLinearAlgorithm[
    * @param data DStream containing labeled data
    */
   def trainOn(data: DStream[LabeledPoint]) {
-    if (Option(model) == None) {
-      logError("Model must be initialized before starting training")
-      throw new IllegalArgumentException
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting training.")
     }
     data.foreachRDD { (rdd, time) =>
-        model = Option(algorithm.run(rdd, model.get.weights))
-        logInfo("Model updated at time %s".format(time.toString))
-        val display = model.get.weights.size match {
-          case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
-          case _ => model.get.weights.toArray.mkString("[", ",", "]")
+      val initialWeights =
+        model match {
+          case Some(m) =>
+            m.weights
+          case None =>
+            val numFeatures = rdd.first().features.size
+            Vectors.dense(numFeatures)
         }
-        logInfo("Current model: weights, %s".format (display))
+      model = Some(algorithm.run(rdd, initialWeights))
+      logInfo("Model updated at time %s".format(time.toString))
+      val display = model.get.weights.size match {
+        case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+        case _ => model.get.weights.toArray.mkString("[", ",", "]")
+      }
+      logInfo("Current model: weights, %s".format (display))
     }
   }
 
@@ -99,10 +106,8 @@ abstract class StreamingLinearAlgorithm[
    * @return DStream containing predictions
    */
   def predictOn(data: DStream[Vector]): DStream[Double] = {
-    if (Option(model) == None) {
-      val msg = "Model must be initialized before starting prediction"
-      logError(msg)
-      throw new IllegalArgumentException(msg)
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting prediction.")
     }
     data.map(model.get.predict)
   }
@@ -114,10 +119,8 @@ abstract class StreamingLinearAlgorithm[
    * @return DStream containing the input keys and the predictions as values
    */
   def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
-    if (Option(model) == None) {
-      val msg = "Model must be initialized before starting prediction"
-      logError(msg)
-      throw new IllegalArgumentException(msg)
+    if (model.isEmpty) {
+      throw new IllegalArgumentException("Model must be initialized before starting prediction")
     }
     data.mapValues(model.get.predict)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index c0625b4880..e5e6301127 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (
 
   /** Set the initial weights. Default: [0.0, 0.0]. */
   def setInitialWeights(initialWeights: Vector): this.type = {
-    this.model = Option(algorithm.createModel(initialWeights, 0.0))
+    this.model = Some(algorithm.createModel(initialWeights, 0.0))
     this
   }
 
--
cgit v1.2.3