author    Xiangrui Meng <meng@databricks.com>    2014-08-15 21:04:29 -0700
committer Xiangrui Meng <meng@databricks.com>    2014-08-15 21:04:29 -0700
commit 5d25c0b74f6397d78164b96afb8b8cbb1b15cfbd (patch)
tree   d2fb3a33cee986c5d7d7c54c88fdbe4d53d98c92 /mllib
parent cc3648774e9a744850107bb187f2828d447e0a48 (diff)
[SPARK-3078][MLLIB] Make LRWithLBFGS API consistent with others
Should ask users to set parameters through the optimizer. dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #1973 from mengxr/lr-lbfgs and squashes the following commits:

e3efbb1 [Xiangrui Meng] fix tests
21b3579 [Xiangrui Meng] fix method name
641eea4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lr-lbfgs
456ab7c [Xiangrui Meng] update LRWithLBFGS
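For context, a minimal usage sketch of the API after this patch. This is illustrative, not part of the patch itself; `training` is an assumed RDD[LabeledPoint] prepared elsewhere:

    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

    // The algorithm no longer takes constructor parameters or exposes its own
    // setConvergenceTol/setNumIterations; configuration goes through `optimizer`.
    val lr = new LogisticRegressionWithLBFGS()
    lr.optimizer
      .setNumIterations(50)     // replaces the removed setter on the algorithm
      .setConvergenceTol(1e-4)  // replaces the removed setConvergenceTol
      .setRegParam(0.1)         // L2 regularization via the default SquaredL2Updater
    val model = lr.run(training) // training: RDD[LabeledPoint], assumed to exist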
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala      | 40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala                     |  9
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala |  5
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala                | 24
4 files changed, 29 insertions(+), 49 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 6790c86f65..486bdbfa9c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -73,6 +73,8 @@ class LogisticRegressionModel (
/**
* Train a classification model for Logistic Regression using Stochastic Gradient Descent.
* NOTE: Labels used in Logistic Regression should be {0, 1}
+ *
+ * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
*/
class LogisticRegressionWithSGD private (
private var stepSize: Double,
@@ -191,51 +193,19 @@ object LogisticRegressionWithSGD {
/**
* Train a classification model for Logistic Regression using Limited-memory BFGS.
+ * Standard feature scaling and L2 regularization are used by default.
* NOTE: Labels used in Logistic Regression should be {0, 1}
*/
-class LogisticRegressionWithLBFGS private (
- private var convergenceTol: Double,
- private var maxNumIterations: Int,
- private var regParam: Double)
+class LogisticRegressionWithLBFGS
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
- /**
- * Construct a LogisticRegression object with default parameters
- */
- def this() = this(1E-4, 100, 0.0)
-
this.setFeatureScaling(true)
- private val gradient = new LogisticGradient()
- private val updater = new SimpleUpdater()
- // Have to return new LBFGS object every time since users can reset the parameters anytime.
- override def optimizer = new LBFGS(gradient, updater)
- .setNumCorrections(10)
- .setConvergenceTol(convergenceTol)
- .setMaxNumIterations(maxNumIterations)
- .setRegParam(regParam)
+ override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
override protected val validators = List(DataValidators.binaryLabelValidator)
- /**
- * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
- * Smaller value will lead to higher accuracy with the cost of more iterations.
- */
- def setConvergenceTol(convergenceTol: Double): this.type = {
- this.convergenceTol = convergenceTol
- this
- }
-
- /**
- * Set the maximal number of iterations for L-BFGS. Default 100.
- */
- def setNumIterations(numIterations: Int): this.type = {
- this.maxNumIterations = numIterations
- this
- }
-
override protected def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
}
-
}
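Note the subtle change from `override def optimizer` (a fresh LBFGS built on every access) to `override val optimizer` (one shared instance): settings now persist on the algorithm across runs. A hedged sketch, assuming two prepared RDD[LabeledPoint] inputs `trainA` and `trainB`:

    val lr = new LogisticRegressionWithLBFGS()
    lr.optimizer.setRegParam(0.01) // configured once on the shared LBFGS instance
    val modelA = lr.run(trainA)    // both runs see regParam = 0.01,
    val modelB = lr.run(trainB)    // since optimizer is a val, not rebuilt per call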
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 033fe44f34..d16d0daf08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -69,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
/**
* Set the maximal number of iterations for L-BFGS. Default 100.
+ * @deprecated use [[LBFGS#setNumIterations]] instead
*/
+ @deprecated("use setNumIterations instead", "1.1.0")
def setMaxNumIterations(iters: Int): this.type = {
+ this.setNumIterations(iters)
+ }
+
+ /**
+ * Set the maximal number of iterations for L-BFGS. Default 100.
+ */
+ def setNumIterations(iters: Int): this.type = {
this.maxNumIterations = iters
this
}
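For callers migrating, the old setter still compiles but now delegates and emits a deprecation warning. A hedged sketch of both spellings:

    import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

    val opt = new LBFGS(new LogisticGradient, new SquaredL2Updater)
    opt.setNumIterations(100)    // preferred name as of this patch
    opt.setMaxNumIterations(100) // deprecated alias since 1.1.0; delegates to setNumIterations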
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index bc05b20468..862178694a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -272,8 +272,9 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont
}.cache()
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
- val model =
- (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points)
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+ lr.optimizer.setNumIterations(2)
+ val model = lr.run(points)
val predictions = model.predict(points.map(_.features))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 5f4c24115a..ccba004baa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -55,7 +55,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray)
val convergenceTol = 1e-12
- val maxNumIterations = 10
+ val numIterations = 10
val (_, loss) = LBFGS.runLBFGS(
dataRDD,
@@ -63,7 +63,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
simpleUpdater,
numCorrections,
convergenceTol,
- maxNumIterations,
+ numIterations,
regParam,
initialWeightsWithIntercept)
@@ -99,7 +99,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
// Prepare another non-zero weights to compare the loss in the first iteration.
val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
val convergenceTol = 1e-12
- val maxNumIterations = 10
+ val numIterations = 10
val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS(
dataRDD,
@@ -107,7 +107,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
squaredL2Updater,
numCorrections,
convergenceTol,
- maxNumIterations,
+ numIterations,
regParam,
initialWeightsWithIntercept)
@@ -140,10 +140,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
/**
* For the first run, we set the convergenceTol to 0.0, so that the algorithm will
- * run up to the maxNumIterations which is 8 here.
+ * run up to the numIterations which is 8 here.
*/
val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0)
- val maxNumIterations = 8
+ val numIterations = 8
var convergenceTol = 0.0
val (_, lossLBFGS1) = LBFGS.runLBFGS(
@@ -152,7 +152,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
squaredL2Updater,
numCorrections,
convergenceTol,
- maxNumIterations,
+ numIterations,
regParam,
initialWeightsWithIntercept)
@@ -167,7 +167,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
squaredL2Updater,
numCorrections,
convergenceTol,
- maxNumIterations,
+ numIterations,
regParam,
initialWeightsWithIntercept)
@@ -182,7 +182,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
squaredL2Updater,
numCorrections,
convergenceTol,
- maxNumIterations,
+ numIterations,
regParam,
initialWeightsWithIntercept)
@@ -200,12 +200,12 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
// Prepare another non-zero weights to compare the loss in the first iteration.
val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
val convergenceTol = 1e-12
- val maxNumIterations = 10
+ val numIterations = 10
val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater)
.setNumCorrections(numCorrections)
.setConvergenceTol(convergenceTol)
- .setMaxNumIterations(maxNumIterations)
+ .setNumIterations(numIterations)
.setRegParam(regParam)
val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept)
@@ -241,7 +241,7 @@ class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
.setNumCorrections(1)
.setConvergenceTol(1e-12)
- .setMaxNumIterations(1)
+ .setNumIterations(1)
.setRegParam(1.0)
val random = new Random(0)
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.