aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-18 15:57:33 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-18 15:57:48 -0800
commit4ae78abe66e593ac8bf9de37eca80413730c431b (patch)
tree78d007628b68b4f0a83281772ee1b4ea15d4f13d /python
parenta93d64c8c677f7121599b21883e1671e1226ec0b (diff)
downloadspark-4ae78abe66e593ac8bf9de37eca80413730c431b.tar.gz
spark-4ae78abe66e593ac8bf9de37eca80413730c431b.tar.bz2
spark-4ae78abe66e593ac8bf9de37eca80413730c431b.zip
[SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS
``` class LogisticRegressionWithLBFGS | train(cls, data, iterations=100, initialWeights=None, corrections=10, tolerance=0.0001, regParam=0.01, intercept=False) | Train a logistic regression model on the given data. | | :param data: The training data, an RDD of LabeledPoint. | :param iterations: The number of iterations (default: 100). | :param initialWeights: The initial weights (default: None). | :param regParam: The regularizer parameter (default: 0.01). | :param regType: The type of regularizer used for training | our model. | :Allowed values: | - "l1" for using L1 regularization | - "l2" for using L2 regularization | - None for no regularization | (default: "l2") | :param intercept: Boolean parameter which indicates the use | or not of the augmented representation for | training data (i.e. whether bias features | are activated or not). | :param corrections: The number of corrections used in the LBFGS update (default: 10). | :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4). | | >>> data = [ | ... LabeledPoint(0.0, [0.0, 1.0]), | ... LabeledPoint(1.0, [1.0, 0.0]), | ... ] | >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) | >>> lrm.predict([1.0, 0.0]) | 1 | >>> lrm.predict([0.0, 1.0]) | 0 | >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect() | [1, 0] ``` Author: Davies Liu <davies@databricks.com> Closes #3307 from davies/lbfgs and squashes the following commits: 34bd986 [Davies Liu] Merge branch 'master' of http://git-wip-us.apache.org/repos/asf/spark into lbfgs 5a945a6 [Davies Liu] address comments 941061b [Davies Liu] Merge branch 'master' of github.com:apache/spark into lbfgs 03e5543 [Davies Liu] add it to docs ed2f9a8 [Davies Liu] add regType 76cd1b6 [Davies Liu] reorder arguments 4429a74 [Davies Liu] Update classification.py 9252783 [Davies Liu] python api for LogisticRegressionWithLBFGS (cherry picked from commit d2e29516f2064f93f3a9070c91fc7460706e0b0a) Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/classification.py57
1 files changed, 53 insertions, 4 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index ee0729b1eb..f14d0ed11c 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -26,8 +26,8 @@ from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
-__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
- 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
+__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
+ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
class LinearBinaryClassificationModel(LinearModel):
@@ -151,7 +151,7 @@ class LogisticRegressionWithSGD(object):
(default: "l2")
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
@@ -164,6 +164,55 @@ class LogisticRegressionWithSGD(object):
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
+class LogisticRegressionWithLBFGS(object):
+
+ @classmethod
+ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
+ intercept=False, corrections=10, tolerance=1e-4):
+ """
+ Train a logistic regression model on the given data.
+
+ :param data: The training data, an RDD of LabeledPoint.
+ :param iterations: The number of iterations (default: 100).
+ :param initialWeights: The initial weights (default: None).
+ :param regParam: The regularizer parameter (default: 0.01).
+ :param regType: The type of regularizer used for training
+ our model.
+
+ :Allowed values:
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
+
+ (default: "l2")
+
+ :param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ :param corrections: The number of corrections used in the LBFGS
+ update (default: 10).
+ :param tolerance: The convergence tolerance of iterations for
+ L-BFGS (default: 1e-4).
+
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0, 1.0]),
+ ... LabeledPoint(1.0, [1.0, 0.0]),
+ ... ]
+ >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
+ >>> lrm.predict([1.0, 0.0])
+ 1
+ >>> lrm.predict([0.0, 1.0])
+ 0
+ """
+ def train(rdd, i):
+ return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
+ float(regParam), str(regType), bool(intercept), int(corrections),
+ float(tolerance))
+
+ return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
+
+
class SVMModel(LinearBinaryClassificationModel):
"""A support vector machine.
@@ -241,7 +290,7 @@ class SVMWithSGD(object):
(default: "l2")
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).