diff options
author | Davies Liu <davies@databricks.com> | 2014-11-18 15:57:33 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-18 15:57:33 -0800 |
commit | d2e29516f2064f93f3a9070c91fc7460706e0b0a (patch) | |
tree | 9a7758ecd722e59f3a6662f20c7563d6ce60636a /mllib | |
parent | 010bc86e40a0e54b6850b75abd6105e70eb1af10 (diff) | |
download | spark-d2e29516f2064f93f3a9070c91fc7460706e0b0a.tar.gz spark-d2e29516f2064f93f3a9070c91fc7460706e0b0a.tar.bz2 spark-d2e29516f2064f93f3a9070c91fc7460706e0b0a.zip |
[SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS
```
class LogisticRegressionWithLBFGS
| train(cls, data, iterations=100, initialWeights=None, corrections=10, tolerance=0.0001, regParam=0.01, intercept=False)
| Train a logistic regression model on the given data.
|
| :param data: The training data, an RDD of LabeledPoint.
| :param iterations: The number of iterations (default: 100).
| :param initialWeights: The initial weights (default: None).
| :param regParam: The regularizer parameter (default: 0.01).
| :param regType: The type of regularizer used for training
| our model.
| :Allowed values:
| - "l1" for using L1 regularization
| - "l2" for using L2 regularization
| - None for no regularization
| (default: "l2")
| :param intercept: Boolean parameter which indicates the use
| or not of the augmented representation for
| training data (i.e. whether bias features
| are activated or not).
| :param corrections: The number of corrections used in the LBFGS update (default: 10).
| :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4).
|
| >>> data = [
| ... LabeledPoint(0.0, [0.0, 1.0]),
| ... LabeledPoint(1.0, [1.0, 0.0]),
| ... ]
| >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
| >>> lrm.predict([1.0, 0.0])
| 1
| >>> lrm.predict([0.0, 1.0])
| 0
| >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
| [1, 0]
```
Author: Davies Liu <davies@databricks.com>
Closes #3307 from davies/lbfgs and squashes the following commits:
34bd986 [Davies Liu] Merge branch 'master' of http://git-wip-us.apache.org/repos/asf/spark into lbfgs
5a945a6 [Davies Liu] address comments
941061b [Davies Liu] Merge branch 'master' of github.com:apache/spark into lbfgs
03e5543 [Davies Liu] add it to docs
ed2f9a8 [Davies Liu] add regType
76cd1b6 [Davies Liu] reorder arguments
4429a74 [Davies Liu] Update classification.py
9252783 [Davies Liu] python api for LogisticRegressionWithLBFGS
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index c8476a5370..6f94b7f483 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -230,6 +230,41 @@ class PythonMLLibAPI extends Serializable { } /** + * Java stub for Python mllib LogisticRegressionWithLBFGS.train() + */ + def trainLogisticRegressionModelWithLBFGS( + data: JavaRDD[LabeledPoint], + numIterations: Int, + initialWeights: Vector, + regParam: Double, + regType: String, + intercept: Boolean, + corrections: Int, + tolerance: Double): JList[Object] = { + val LogRegAlg = new LogisticRegressionWithLBFGS() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setNumCorrections(corrections) + .setConvergenceTol(tolerance) + if (regType == "l2") { + LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + LogRegAlg.optimizer.setUpdater(new L1Updater) + } else if (regType == null) { + LogRegAlg.optimizer.setUpdater(new SimpleUpdater) + } else { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: ['l1', 'l2', None].") + } + trainRegressionModel( + LogRegAlg, + data, + initialWeights) + } + + /** * Java stub for NaiveBayes.train() */ def trainNaiveBayes( |