aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r--python/pyspark/mllib/classification.py61
1 files changed, 55 insertions, 6 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 2bbb9c3fca..5ec1a8084d 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -73,11 +73,36 @@ class LogisticRegressionModel(LinearModel):
class LogisticRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
- """Train a logistic regression model on the given data."""
+ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
+ initialWeights=None, regParam=1.0, regType=None, intercept=False):
+ """
+ Train a logistic regression model on the given data.
+
+ @param data: The training data.
+ @param iterations: The number of iterations (default: 100).
+ @param step: The step parameter used in SGD
+ (default: 1.0).
+ @param miniBatchFraction: Fraction of data to be used for each SGD
+ iteration.
+ @param initialWeights: The initial weights (default: None).
+ @param regParam: The regularizer parameter (default: 1.0).
+ @param regType: The type of regularizer used for training
+ our model.
+ Allowed values: "l1" for using L1Updater,
+ "l2" for using
+ SquaredL2Updater,
+ "none" for no regularizer.
+ (default: "none")
+ @param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ """
sc = data.context
+ if regType is None:
+ regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
- d._jrdd, iterations, step, miniBatchFraction, i)
+ d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
initialWeights)
@@ -115,11 +140,35 @@ class SVMModel(LinearModel):
class SVMWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
- miniBatchFraction=1.0, initialWeights=None):
- """Train a support vector machine on the given data."""
+ miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
+ """
+ Train a support vector machine on the given data.
+
+ @param data: The training data.
+ @param iterations: The number of iterations (default: 100).
+ @param step: The step parameter used in SGD
+ (default: 1.0).
+ @param regParam: The regularizer parameter (default: 1.0).
+ @param miniBatchFraction: Fraction of data to be used for each SGD
+ iteration.
+ @param initialWeights: The initial weights (default: None).
+ @param regType: The type of regularizer used for training
+ our model.
+ Allowed values: "l1" for using L1Updater,
+ "l2" for using
+ SquaredL2Updater,
+ "none" for no regularizer.
+ (default: "none")
+ @param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ """
sc = data.context
+ if regType is None:
+ regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
- d._jrdd, iterations, step, regParam, miniBatchFraction, i)
+ d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)