about summary refs log tree commit diff
path: root/python/pyspark/mllib/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r-- python/pyspark/mllib/regression.py 24
1 file changed, 10 insertions, 14 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 266b31d3fa..bc7de6d2e8 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -113,10 +113,9 @@ class LinearRegressionWithSGD(object):
miniBatchFraction=1.0, initialWeights=None):
"""Train a linear regression model on the given data."""
sc = data.context
- return _regression_train_wrapper(sc, lambda d, i:
- sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
- d._jrdd, iterations, step, miniBatchFraction, i),
- LinearRegressionModel, data, initialWeights)
+ train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
+ d._jrdd, iterations, step, miniBatchFraction, i)
+ return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)
class LassoModel(LinearRegressionModelBase):
@@ -157,10 +156,9 @@ class LassoWithSGD(object):
miniBatchFraction=1.0, initialWeights=None):
"""Train a Lasso regression model on the given data."""
sc = data.context
- return _regression_train_wrapper(sc, lambda d, i:
- sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
- iterations, step, regParam, miniBatchFraction, i),
- LassoModel, data, initialWeights)
+ train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(
+ d._jrdd, iterations, step, regParam, miniBatchFraction, i)
+ return _regression_train_wrapper(sc, train_f, LassoModel, data, initialWeights)
class RidgeRegressionModel(LinearRegressionModelBase):
@@ -201,18 +199,16 @@ class RidgeRegressionWithSGD(object):
miniBatchFraction=1.0, initialWeights=None):
"""Train a ridge regression model on the given data."""
sc = data.context
- return _regression_train_wrapper(sc, lambda d, i:
- sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd,
- iterations, step, regParam, miniBatchFraction, i),
- RidgeRegressionModel, data, initialWeights)
+ train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(
+ d._jrdd, iterations, step, regParam, miniBatchFraction, i)
+ return _regression_train_wrapper(sc, train_func, RidgeRegressionModel, data, initialWeights)
def _test():
import doctest
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs,
- optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)