about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2015-02-03 00:14:43 -0800
committer Xiangrui Meng <meng@databricks.com> 2015-02-03 00:14:43 -0800
commit 659329f9ee51ca8ae6232e07c45b5d9144d49667 (patch)
tree 5d28feb5ea73e88d9aef547d14f45b4db3bd10ef /mllib
parent 980764f3c0c065cc32454a036e8d0ead5a92037b (diff)
download spark-659329f9ee51ca8ae6232e07c45b5d9144d49667.tar.gz
spark-659329f9ee51ca8ae6232e07c45b5d9144d49667.tar.bz2
spark-659329f9ee51ca8ae6232e07c45b5d9144d49667.zip
[minor] update streaming linear algorithms
Author: Xiangrui Meng <meng@databricks.com>

Closes #4329 from mengxr/streaming-lr and squashes the following commits:

78731e1 [Xiangrui Meng] update streaming linear algorithms
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala 41
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala 2
3 files changed, 24 insertions, 22 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
index eabd2162e2..6a3893d0e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
@@ -88,8 +88,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
- this.model = Option(algorithm.createModel(initialWeights, 0.0))
+ this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index 39a0dee931..44a8dbb994 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.streaming.dstream.DStream
/**
@@ -58,7 +58,7 @@ abstract class StreamingLinearAlgorithm[
A <: GeneralizedLinearAlgorithm[M]] extends Logging {
/** The model to be updated and used for prediction. */
- protected var model: Option[M] = null
+ protected var model: Option[M] = None
/** The algorithm to use for updating. */
protected val algorithm: A
@@ -77,18 +77,25 @@ abstract class StreamingLinearAlgorithm[
* @param data DStream containing labeled data
*/
def trainOn(data: DStream[LabeledPoint]) {
- if (Option(model) == None) {
- logError("Model must be initialized before starting training")
- throw new IllegalArgumentException
+ if (model.isEmpty) {
+ throw new IllegalArgumentException("Model must be initialized before starting training.")
}
data.foreachRDD { (rdd, time) =>
- model = Option(algorithm.run(rdd, model.get.weights))
- logInfo("Model updated at time %s".format(time.toString))
- val display = model.get.weights.size match {
- case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
- case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ val initialWeights =
+ model match {
+ case Some(m) =>
+ m.weights
+ case None =>
+ val numFeatures = rdd.first().features.size
+ Vectors.dense(numFeatures)
}
- logInfo("Current model: weights, %s".format (display))
+ model = Some(algorithm.run(rdd, initialWeights))
+ logInfo("Model updated at time %s".format(time.toString))
+ val display = model.get.weights.size match {
+ case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+ case _ => model.get.weights.toArray.mkString("[", ",", "]")
+ }
+ logInfo("Current model: weights, %s".format (display))
}
}
@@ -99,10 +106,8 @@ abstract class StreamingLinearAlgorithm[
* @return DStream containing predictions
*/
def predictOn(data: DStream[Vector]): DStream[Double] = {
- if (Option(model) == None) {
- val msg = "Model must be initialized before starting prediction"
- logError(msg)
- throw new IllegalArgumentException(msg)
+ if (model.isEmpty) {
+ throw new IllegalArgumentException("Model must be initialized before starting prediction.")
}
data.map(model.get.predict)
}
@@ -114,10 +119,8 @@ abstract class StreamingLinearAlgorithm[
* @return DStream containing the input keys and the predictions as values
*/
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
- if (Option(model) == None) {
- val msg = "Model must be initialized before starting prediction"
- logError(msg)
- throw new IllegalArgumentException(msg)
+ if (model.isEmpty) {
+ throw new IllegalArgumentException("Model must be initialized before starting prediction")
}
data.mapValues(model.get.predict)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index c0625b4880..e5e6301127 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (
/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
- this.model = Option(algorithm.createModel(initialWeights, 0.0))
+ this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
}