From 52de3acca4ce8c36fd4c9ce162473a091701bbc7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 23 Jul 2015 18:53:07 -0700 Subject: [SPARK-9122] [MLLIB] [PySpark] spark.mllib regression support batch predict spark.mllib support batch predict for LinearRegressionModel, RidgeRegressionModel and LassoModel. Author: Yanbo Liang Closes #7614 from yanboliang/spark-9122 and squashes the following commits: 4e610c0 [Yanbo Liang] spark.mllib regression support batch predict --- python/pyspark/mllib/regression.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 8e90adee5f..5b7afc15dd 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -97,9 +97,11 @@ class LinearRegressionModelBase(LinearModel): def predict(self, x): """ - Predict the value of the dependent variable given a vector x - containing values for the independent variables. + Predict the value of the dependent variable given a vector or + an RDD of vectors containing values for the independent variables. """ + if isinstance(x, RDD): + return x.map(self.predict) x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,6 +126,8 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) @@ -267,6 +271,8 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) @@ -382,6 +388,8 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) -- cgit v1.2.3