about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/regression.py 12
1 files changed, 10 insertions, 2 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 8e90adee5f..5b7afc15dd 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -97,9 +97,11 @@ class LinearRegressionModelBase(LinearModel):
def predict(self, x):
"""
- Predict the value of the dependent variable given a vector x
- containing values for the independent variables.
+ Predict the value of the dependent variable given a vector or
+ an RDD of vectors containing values for the independent variables.
"""
+ if isinstance(x, RDD):
+ return x.map(self.predict)
x = _convert_to_vector(x)
return self.weights.dot(x) + self.intercept
@@ -124,6 +126,8 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
+ True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
@@ -267,6 +271,8 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
+ True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
@@ -382,6 +388,8 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
+ True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)