diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-08-18 18:20:54 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-18 18:20:54 -0700 |
commit | 217b5e915e2f21f047dfc4be680cd20d58baf9f8 (patch) | |
tree | da708d2bd7290989bcdd5c884086be2bf7c3096f /mllib/src/main | |
parent | c8b16ca0d86cc60fb960eebf0cb383f159a88b03 (diff) | |
download | spark-217b5e915e2f21f047dfc4be680cd20d58baf9f8.tar.gz spark-217b5e915e2f21f047dfc4be680cd20d58baf9f8.tar.bz2 spark-217b5e915e2f21f047dfc4be680cd20d58baf9f8.zip |
[SPARK-3108][MLLIB] add predictOnValues to StreamingLR and fix predictOn
It is useful in streaming to allow users to carry extra data with the prediction, for monitoring the prediction error for example. freeman-lab
Author: Xiangrui Meng <meng@databricks.com>
Closes #2023 from mengxr/predict-on-values and squashes the following commits:
cac47b8 [Xiangrui Meng] add classtag
2821b3b [Xiangrui Meng] use mapValues
0925efa [Xiangrui Meng] add predictOnValues to StreamingLR and fix predictOn
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b8b0b42611..8db0442a7a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import scala.reflect.ClassTag + import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** @@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[ /** * Use the model to make predictions on batches of data from a DStream * - * @param data DStream containing labeled data + * @param data DStream containing feature vectors * @return DStream containing predictions */ - def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { + def predictOn(data: DStream[Vector]): DStream[Double] = { if (Option(model.weights) == None) { - logError("Initial weights must be set before starting prediction") - throw new IllegalArgumentException + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) } - data.map(x => model.predict(x.features)) + data.map(model.predict) } + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * @param data DStream containing feature vectors + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { + if (Option(model.weights) == None) { + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) + } + data.mapValues(model.predict) + } } |