aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMichael Giannakopoulos <miccagiann@gmail.com>2014-08-01 21:00:31 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-01 21:00:31 -0700
commitc281189222e645d2c87277c269e2102c3c8ccc95 (patch)
treee56b1d46896433d8c859b68870807e0faa0cbd64 /python
parentf6a1899306c5ad766fea122d3ab4b83436d9f6fd (diff)
downloadspark-c281189222e645d2c87277c269e2102c3c8ccc95.tar.gz
spark-c281189222e645d2c87277c269e2102c3c8ccc95.tar.bz2
spark-c281189222e645d2c87277c269e2102c3c8ccc95.zip
[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods.
Related to issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC). Author: Michael Giannakopoulos <miccagiann@gmail.com> Closes #1624 from miccagiann/new-branch and squashes the following commits: c02e5f5 [Michael Giannakopoulos] Merge cleanly with upstream/master. 8dcb888 [Michael Giannakopoulos] Putting the if/else if statements in brackets. fed8eaa [Michael Giannakopoulos] Adding a space in the message related to the IllegalArgumentException. 44e6ff0 [Michael Giannakopoulos] Adding a blank line before python class LinearRegressionWithSGD. 8eba9c5 [Michael Giannakopoulos] Change function signatures. Exception is thrown from the scala component and not from the python one. 638be47 [Michael Giannakopoulos] Modified code to comply with code standards. ec50ee9 [Michael Giannakopoulos] Shorten the if-elif-else statement in regression.py file b962744 [Michael Giannakopoulos] Replaced the enum classes, with strings-keywords for defining the values of 'regType' parameter. 78853ec [Michael Giannakopoulos] Providing intercept and regualizer functionallity for linear methods in only one function. 3ac8874 [Michael Giannakopoulos] Added support for regularizer and intercection parameters for linear regression method.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/regression.py32
1 files changed, 28 insertions, 4 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index b84bc531de..041b119269 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase):
class LinearRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0,
- miniBatchFraction=1.0, initialWeights=None):
- """Train a linear regression model on the given data."""
+ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
+ initialWeights=None, regParam=1.0, regType=None, intercept=False):
+ """
+ Train a linear regression model on the given data.
+
+ @param data: The training data.
+ @param iterations: The number of iterations (default: 100).
+ @param step: The step parameter used in SGD
+ (default: 1.0).
+ @param miniBatchFraction: Fraction of data to be used for each SGD
+ iteration.
+ @param initialWeights: The initial weights (default: None).
+ @param regParam: The regularizer parameter (default: 1.0).
+ @param regType: The type of regularizer used for training
+ our model.
+ Allowed values: "l1" for using L1Updater,
+ "l2" for using
+ SquaredL2Updater,
+ "none" for no regularizer.
+ (default: "none")
+ @param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ """
sc = data.context
+ if regType is None:
+ regType = "none"
train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
- d._jrdd, iterations, step, miniBatchFraction, i)
+ d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights)