diff options
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r-- | python/pyspark/mllib/regression.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index a3a68b29e0..e90b72893f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -47,14 +47,15 @@ class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) - >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initial_weights=array([1.0])) """ class LinearRegressionWithSGD(object): @classmethod - def train(cls, sc, data, iterations=100, step=1.0, + def train(cls, data, iterations=100, step=1.0, mini_batch_fraction=1.0, initial_weights=None): """Train a linear regression model on the given data.""" + sc = data.context return _regression_train_wrapper(sc, lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( d._jrdd, iterations, step, mini_batch_fraction, i), @@ -65,14 +66,15 @@ class LassoModel(LinearRegressionModelBase): l_1 penalty term. >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) - >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + >>> lrm = LassoWithSGD.train(sc.parallelize(data), initial_weights=array([1.0])) """ - + class LassoWithSGD(object): @classmethod - def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + def train(cls, data, iterations=100, step=1.0, reg_param=1.0, mini_batch_fraction=1.0, initial_weights=None): """Train a Lasso regression model on the given data.""" + sc = data.context return _regression_train_wrapper(sc, lambda d, i: sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd, iterations, step, reg_param, mini_batch_fraction, i), @@ -83,14 +85,15 @@ class RidgeRegressionModel(LinearRegressionModelBase): l_2 penalty term. >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) - >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0])) + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initial_weights=array([1.0])) """ class RidgeRegressionWithSGD(object): @classmethod - def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0, + def train(cls, data, iterations=100, step=1.0, reg_param=1.0, mini_batch_fraction=1.0, initial_weights=None): """Train a ridge regression model on the given data.""" + sc = data.context return _regression_train_wrapper(sc, lambda d, i: sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd, iterations, step, reg_param, mini_batch_fraction, i), |