 mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala      |  2 +-
 mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 8f187c9df5..7bbed9c8fd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -60,7 +60,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
* Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
* Smaller value will lead to higher accuracy with the cost of more iterations.
*/
- def setConvergenceTol(tolerance: Int): this.type = {
+ def setConvergenceTol(tolerance: Double): this.type = {
this.convergenceTol = tolerance
this
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 4b1850659a..fe7a9033cd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
assert(lossLBFGS3.length == 6)
assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
}
+
+ test("Optimize via class LBFGS.") {
+ val regParam = 0.2
+
+ // Prepare another non-zero weights to compare the loss in the first iteration.
+ val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
+ val convergenceTol = 1e-12
+ val maxNumIterations = 10
+
+ val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater)
+ .setNumCorrections(numCorrections)
+ .setConvergenceTol(convergenceTol)
+ .setMaxNumIterations(maxNumIterations)
+ .setRegParam(regParam)
+
+ val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept)
+
+ val numGDIterations = 50
+ val stepSize = 1.0
+ val (weightGD, _) = GradientDescent.runMiniBatchSGD(
+ dataRDD,
+ gradient,
+ squaredL2Updater,
+ stepSize,
+ numGDIterations,
+ regParam,
+ miniBatchFrac,
+ initialWeightsWithIntercept)
+
+ // for class LBFGS and the optimize method, we only look at the weights
+ assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
+ compareDouble(weightLBFGS(1), weightGD(1), 0.02),
+ "The weight differences between LBFGS and GD should be within 2%.")
+ }
}
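
The patch above changes `setConvergenceTol` to accept a `Double`, so fractional tolerances such as `1e-4` are no longer truncated to an `Int`. Below is a minimal usage sketch, not part of the patch, showing how the corrected setter would be called on the `LBFGS` optimizer; the dataset, tolerance, and iteration counts are illustrative placeholders.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

object LBFGSToleranceExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "LBFGSToleranceExample")

    // Toy (label, features) pairs; the first feature acts as an intercept column.
    val data = sc.parallelize(Seq(
      (1.0, Vectors.dense(1.0, 2.0)),
      (0.0, Vectors.dense(1.0, -1.5)),
      (1.0, Vectors.dense(1.0, 3.0))
    ))

    val lbfgs = new LBFGS(new LogisticGradient(), new SquaredL2Updater())
      .setNumCorrections(10)
      .setConvergenceTol(1e-4)      // now accepted as a Double instead of being truncated to Int
      .setMaxNumIterations(20)
      .setRegParam(0.1)

    // optimize returns the weight vector found from the given initial weights.
    val weights = lbfgs.optimize(data, Vectors.dense(0.0, 0.0))
    println(s"Optimized weights: $weights")

    sc.stop()
  }
}
```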