author | Michael Giannakopoulos <miccagiann@gmail.com> | 2014-08-01 21:00:31 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-08-01 21:00:31 -0700 |
commit | c281189222e645d2c87277c269e2102c3c8ccc95 (patch) | |
tree | e56b1d46896433d8c859b68870807e0faa0cbd64 /python | |
parent | f6a1899306c5ad766fea122d3ab4b83436d9f6fd (diff) | |
[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods.
Related to issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC).
Author: Michael Giannakopoulos <miccagiann@gmail.com>
Closes #1624 from miccagiann/new-branch and squashes the following commits:
c02e5f5 [Michael Giannakopoulos] Merge cleanly with upstream/master.
8dcb888 [Michael Giannakopoulos] Putting the if/else if statements in brackets.
fed8eaa [Michael Giannakopoulos] Adding a space in the message related to the IllegalArgumentException.
44e6ff0 [Michael Giannakopoulos] Adding a blank line before python class LinearRegressionWithSGD.
8eba9c5 [Michael Giannakopoulos] Change function signatures. Exception is thrown from the scala component and not from the python one.
638be47 [Michael Giannakopoulos] Modified code to comply with code standards.
ec50ee9 [Michael Giannakopoulos] Shorten the if-elif-else statement in regression.py file
b962744 [Michael Giannakopoulos] Replaced the enum classes, with strings-keywords for defining the values of 'regType' parameter.
78853ec [Michael Giannakopoulos] Providing intercept and regularizer functionality for linear methods in only one function.
3ac8874 [Michael Giannakopoulos] Added support for regularizer and intercept parameters for linear regression method.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/regression.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index b84bc531de..041b119269 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase):
 
 class LinearRegressionWithSGD(object):
     @classmethod
-    def train(cls, data, iterations=100, step=1.0,
-              miniBatchFraction=1.0, initialWeights=None):
-        """Train a linear regression model on the given data."""
+    def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
+              initialWeights=None, regParam=1.0, regType=None, intercept=False):
+        """
+        Train a linear regression model on the given data.
+
+        @param data:              The training data.
+        @param iterations:        The number of iterations (default: 100).
+        @param step:              The step parameter used in SGD
+                                  (default: 1.0).
+        @param miniBatchFraction: Fraction of data to be used for each SGD
+                                  iteration.
+        @param initialWeights:    The initial weights (default: None).
+        @param regParam:          The regularizer parameter (default: 1.0).
+        @param regType:           The type of regularizer used for training
+                                  our model.
+                                  Allowed values: "l1" for using L1Updater,
+                                                  "l2" for using
+                                                       SquaredL2Updater,
+                                                  "none" for no regularizer.
+                                                  (default: "none")
+        @param intercept:         Boolean parameter which indicates the use
+                                  or not of the augmented representation for
+                                  training data (i.e. whether bias features
+                                  are activated or not).
+        """
         sc = data.context
+        if regType is None:
+            regType = "none"
         train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
-            d._jrdd, iterations, step, miniBatchFraction, i)
+            d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
         return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)
 
 
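To make the three `regType` keywords concrete, here is a minimal pure-Python sketch of what each choice plausibly does to a single SGD step. It is not code from this patch. The helper names `resolve_reg_type` and `sgd_step` are hypothetical. The `step / sqrt(iteration)` decay, the L2 shrinkage term and the L1 soft-thresholding are assumptions about how MLlib's `SquaredL2Updater` and `L1Updater` behave. The error message is illustrative, and the `intercept` flag is not modeled.

```python
import math

# The three keywords listed in the new docstring.
VALID_REG_TYPES = ("l1", "l2", "none")


def resolve_reg_type(reg_type=None):
    """Hypothetical helper: apply the patch's None -> "none" fallback and validate."""
    if reg_type is None:
        reg_type = "none"
    if reg_type not in VALID_REG_TYPES:
        # In the patch, invalid values are rejected on the Scala side;
        # this Python check and its message are only illustrative.
        raise ValueError("regType must be one of %s" % (VALID_REG_TYPES,))
    return reg_type


def sgd_step(weights, gradient, step, iteration, reg_param=1.0, reg_type=None):
    """Hypothetical single update step under the chosen regularizer."""
    reg_type = resolve_reg_type(reg_type)
    lr = step / math.sqrt(iteration)  # assumed step-size decay
    if reg_type == "l2":
        # Gradient of (reg_param / 2) * ||w||^2 is reg_param * w.
        return [w - lr * (g + reg_param * w) for w, g in zip(weights, gradient)]
    moved = [w - lr * g for w, g in zip(weights, gradient)]
    if reg_type == "l1":
        # Soft-thresholding: shrink each weight toward zero, clipping at zero.
        shrink = reg_param * lr
        return [math.copysign(max(abs(w) - shrink, 0.0), w) for w in moved]
    return moved  # "none": plain gradient step
```

With the signature introduced by this patch, the corresponding PySpark call would look like `LinearRegressionWithSGD.train(data, regParam=0.1, regType="l2", intercept=True)`. Leaving `regType` unset trains without a regularizer.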