aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2014-08-18 18:20:54 -0700
committer Xiangrui Meng <meng@databricks.com> 2014-08-18 18:20:54 -0700
commit 217b5e915e2f21f047dfc4be680cd20d58baf9f8 (patch)
tree da708d2bd7290989bcdd5c884086be2bf7c3096f /mllib/src
parent c8b16ca0d86cc60fb960eebf0cb383f159a88b03 (diff)
download spark-217b5e915e2f21f047dfc4be680cd20d58baf9f8.tar.gz
spark-217b5e915e2f21f047dfc4be680cd20d58baf9f8.tar.bz2
spark-217b5e915e2f21f047dfc4be680cd20d58baf9f8.zip
[SPARK-3108][MLLIB] add predictOnValues to StreamingLR and fix predictOn
It is useful in streaming to allow users to carry extra data with the prediction, for monitoring the prediction error for example.

freeman-lab

Author: Xiangrui Meng <meng@databricks.com>

Closes #2023 from mengxr/predict-on-values and squashes the following commits:

cac47b8 [Xiangrui Meng] add classtag
2821b3b [Xiangrui Meng] use mapValues
0925efa [Xiangrui Meng] add predictOnValues to StreamingLR and fix predictOn
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala 31
1 files changed, 25 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index b8b0b42611..8db0442a7a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -17,8 +17,12 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.annotation.DeveloperApi
+import scala.reflect.ClassTag
+
import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.DStream
/**
@@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[
/**
* Use the model to make predictions on batches of data from a DStream
*
- * @param data DStream containing labeled data
+ * @param data DStream containing feature vectors
* @return DStream containing predictions
*/
- def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
+ def predictOn(data: DStream[Vector]): DStream[Double] = {
if (Option(model.weights) == None) {
- logError("Initial weights must be set before starting prediction")
- throw new IllegalArgumentException
+ val msg = "Initial weights must be set before starting prediction"
+ logError(msg)
+ throw new IllegalArgumentException(msg)
}
- data.map(x => model.predict(x.features))
+ data.map(model.predict)
}
+ /**
+ * Use the model to make predictions on the values of a DStream and carry over its keys.
+ * @param data DStream containing feature vectors
+ * @tparam K key type
+ * @return DStream containing the input keys and the predictions as values
+ */
+ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
+ if (Option(model.weights) == None) {
+ val msg = "Initial weights must be set before starting prediction"
+ logError(msg)
+ throw new IllegalArgumentException(msg)
+ }
+ data.mapValues(model.predict)
+ }
}