about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/mllib/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r-- python/pyspark/mllib/regression.py 17
1 files changed, 10 insertions, 7 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index a3a68b29e0..e90b72893f 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -47,14 +47,15 @@ class LinearRegressionModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit.
>>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
- >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initial_weights=array([1.0]))
"""
class LinearRegressionWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0,
+ def train(cls, data, iterations=100, step=1.0,
mini_batch_fraction=1.0, initial_weights=None):
"""Train a linear regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
d._jrdd, iterations, step, mini_batch_fraction, i),
@@ -65,14 +66,15 @@ class LassoModel(LinearRegressionModelBase):
l_1 penalty term.
>>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
- >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ >>> lrm = LassoWithSGD.train(sc.parallelize(data), initial_weights=array([1.0]))
"""
-
+
class LassoWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ def train(cls, data, iterations=100, step=1.0, reg_param=1.0,
mini_batch_fraction=1.0, initial_weights=None):
"""Train a Lasso regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
iterations, step, reg_param, mini_batch_fraction, i),
@@ -83,14 +85,15 @@ class RidgeRegressionModel(LinearRegressionModelBase):
l_2 penalty term.
>>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
- >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initial_weights=array([1.0]))
"""
class RidgeRegressionWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ def train(cls, data, iterations=100, step=1.0, reg_param=1.0,
mini_batch_fraction=1.0, initial_weights=None):
"""Train a ridge regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd,
iterations, step, reg_param, mini_batch_fraction, i),