aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-18 15:57:33 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-18 15:57:33 -0800
commitd2e29516f2064f93f3a9070c91fc7460706e0b0a (patch)
tree9a7758ecd722e59f3a6662f20c7563d6ce60636a /mllib
parent010bc86e40a0e54b6850b75abd6105e70eb1af10 (diff)
downloadspark-d2e29516f2064f93f3a9070c91fc7460706e0b0a.tar.gz
spark-d2e29516f2064f93f3a9070c91fc7460706e0b0a.tar.bz2
spark-d2e29516f2064f93f3a9070c91fc7460706e0b0a.zip
[SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS
``` class LogisticRegressionWithLBFGS | train(cls, data, iterations=100, initialWeights=None, corrections=10, tolerance=0.0001, regParam=0.01, intercept=False) | Train a logistic regression model on the given data. | | :param data: The training data, an RDD of LabeledPoint. | :param iterations: The number of iterations (default: 100). | :param initialWeights: The initial weights (default: None). | :param regParam: The regularizer parameter (default: 0.01). | :param regType: The type of regularizer used for training | our model. | :Allowed values: | - "l1" for using L1 regularization | - "l2" for using L2 regularization | - None for no regularization | (default: "l2") | :param intercept: Boolean parameter which indicates the use | or not of the augmented representation for | training data (i.e. whether bias features | are activated or not). | :param corrections: The number of corrections used in the LBFGS update (default: 10). | :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4). | | >>> data = [ | ... LabeledPoint(0.0, [0.0, 1.0]), | ... LabeledPoint(1.0, [1.0, 0.0]), | ... ] | >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) | >>> lrm.predict([1.0, 0.0]) | 1 | >>> lrm.predict([0.0, 1.0]) | 0 | >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect() | [1, 0] ``` Author: Davies Liu <davies@databricks.com> Closes #3307 from davies/lbfgs and squashes the following commits: 34bd986 [Davies Liu] Merge branch 'master' of http://git-wip-us.apache.org/repos/asf/spark into lbfgs 5a945a6 [Davies Liu] address comments 941061b [Davies Liu] Merge branch 'master' of github.com:apache/spark into lbfgs 03e5543 [Davies Liu] add it to docs ed2f9a8 [Davies Liu] add regType 76cd1b6 [Davies Liu] reorder arguments 4429a74 [Davies Liu] Update classification.py 9252783 [Davies Liu] python api for LogisticRegressionWithLBFGS
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala35
1 files changed, 35 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c8476a5370..6f94b7f483 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -230,6 +230,41 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib LogisticRegressionWithLBFGS.train()
+ */
+ def trainLogisticRegressionModelWithLBFGS(
+ data: JavaRDD[LabeledPoint],
+ numIterations: Int,
+ initialWeights: Vector,
+ regParam: Double,
+ regType: String,
+ intercept: Boolean,
+ corrections: Int,
+ tolerance: Double): JList[Object] = {
+ val LogRegAlg = new LogisticRegressionWithLBFGS()
+ LogRegAlg.setIntercept(intercept)
+ LogRegAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setNumCorrections(corrections)
+ .setConvergenceTol(tolerance)
+ if (regType == "l2") {
+ LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
+ } else if (regType == "l1") {
+ LogRegAlg.optimizer.setUpdater(new L1Updater)
+ } else if (regType == null) {
+ LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+ }
+ trainRegressionModel(
+ LogRegAlg,
+ data,
+ initialWeights)
+ }
+
+ /**
* Java stub for NaiveBayes.train()
*/
def trainNaiveBayes(