author    lewuathe <lewuathe@me.com>    2015-07-02 15:00:13 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-07-02 15:00:13 -0700
commit    7d9cc9673e47227f58411ca1f4e647cd8233a219 (patch)
tree      14260929c5d9d592ab0149d6cc909d4721ecfb1d /mllib
parent    52508beb650a863ed5c89384414b3b7675cac11e (diff)
[SPARK-3382] [MLLIB] GradientDescent convergence tolerance
GradientDescent can receive a convergence tolerance value. The default value is 0.001. Iteration terminates early once the difference between the solution vectors of consecutive iterations falls below the tolerance set by the user.

Author: lewuathe <lewuathe@me.com>

Closes #3636 from Lewuathe/gd-convergence-tolerance and squashes the following commits:

0b8a9a8 [lewuathe] Update doc
ce91b15 [lewuathe] Merge branch 'master' into gd-convergence-tolerance
4f22c2b [lewuathe] Modify based on SPARK-1503
5e47b82 [lewuathe] Merge branch 'master' into gd-convergence-tolerance
abadb7e [lewuathe] Fix LassoSuite
8fadebd [lewuathe] Fix failed unit tests
ee5de46 [lewuathe] Merge branch 'master' into gd-convergence-tolerance
8313ba2 [lewuathe] Fix styles
0ead94c [lewuathe] Merge branch 'master' into gd-convergence-tolerance
a94cfd5 [lewuathe] Modify some styles
3aef0a2 [lewuathe] Modify converged logic to do relative comparison
f7b19d5 [lewuathe] [SPARK-3382] Clarify comparison logic
e6c9cd2 [lewuathe] [SPARK-3382] Compare with the diff of solution vector
4b125d2 [lewuathe] [SPARK-3382] Fix scala style
e7c10dd [lewuathe] [SPARK-3382] format improvements
f867eea [lewuathe] [SPARK-3382] Modify warning message statements
b9d5e61 [lewuathe] [SPARK-3382] should compare diff inside loss history and convergence tolerance
5433f71 [lewuathe] [SPARK-3382] GradientDescent convergence tolerance
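As a usage illustration (not part of the patch): a minimal sketch of how a caller can enable the new tolerance through the public optimizer, mirroring the LogisticRegressionSuite change below. The helper name trainWithTol and the 0.0005 value are assumptions for the example.

  import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD}
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.rdd.RDD

  def trainWithTol(trainingRDD: RDD[LabeledPoint]): LogisticRegressionModel = {
    val lr = new LogisticRegressionWithSGD()
    // Iteration stops early once consecutive weight vectors differ by less than
    // the tolerance (relative to the weight norm when that norm is > 1).
    lr.optimizer
      .setStepSize(10.0)
      .setNumIterations(20)
      .setConvergenceTol(0.0005)
    lr.run(trainingRDD)
  }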
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala | 105
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala | 1
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala | 45
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala | 1
7 files changed, 144 insertions, 22 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 06e45e10c5..ab7611fd07 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -19,13 +19,14 @@ package org.apache.spark.mllib.optimization
import scala.collection.mutable.ArrayBuffer
-import breeze.linalg.{DenseVector => BDV}
+import breeze.linalg.{DenseVector => BDV, norm}
import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
/**
* Class used to solve an optimization problem using Gradient Descent.
* @param gradient Gradient function to be used.
@@ -38,6 +39,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
private var numIterations: Int = 100
private var regParam: Double = 0.0
private var miniBatchFraction: Double = 1.0
+ private var convergenceTol: Double = 0.001
/**
* Set the initial step size of SGD for the first step. Default 1.0.
@@ -76,6 +78,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
}
/**
+ * Set the convergence tolerance. Default 0.001.
+ * convergenceTol is the condition that decides when iteration terminates,
+ * based on the following logic:
+ * - If the norm of the new solution vector is > 1, the difference between
+ *   solution vectors is compared against a relative tolerance, i.e. it is
+ *   normalized by the norm of the new solution vector.
+ * - If the norm of the new solution vector is <= 1, the difference between
+ *   solution vectors is compared against an absolute tolerance, i.e. it is
+ *   not normalized.
+ * Must be between 0.0 and 1.0 inclusive.
+ */
+ def setConvergenceTol(tolerance: Double): this.type = {
+ require(0.0 <= tolerance && tolerance <= 1.0)
+ this.convergenceTol = tolerance
+ this
+ }
+
+ /**
* Set the gradient function (of the loss function of one single data example)
* to be used for SGD.
*/
@@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va
numIterations,
regParam,
miniBatchFraction,
- initialWeights)
+ initialWeights,
+ convergenceTol)
weights
}
@@ -131,17 +151,20 @@ object GradientDescent extends Logging {
* Sampling, and averaging the subgradients over this subset is performed using one standard
* spark map-reduce in each iteration.
*
- * @param data - Input data for SGD. RDD of the set of data examples, each of
- * the form (label, [feature values]).
- * @param gradient - Gradient object (used to compute the gradient of the loss function of
- * one single data example)
- * @param updater - Updater function to actually perform a gradient step in a given direction.
- * @param stepSize - initial step size for the first step
- * @param numIterations - number of iterations that SGD should be run.
- * @param regParam - regularization parameter
- * @param miniBatchFraction - fraction of the input data set that should be used for
- * one iteration of SGD. Default value 1.0.
- *
+ * @param data Input data for SGD. RDD of the set of data examples, each of
+ * the form (label, [feature values]).
+ * @param gradient Gradient object (used to compute the gradient of the loss function of
+ * one single data example)
+ * @param updater Updater function to actually perform a gradient step in a given direction.
+ * @param stepSize initial step size for the first step
+ * @param numIterations number of iterations that SGD should be run.
+ * @param regParam regularization parameter
+ * @param miniBatchFraction fraction of the input data set that should be used for
+ * one iteration of SGD. Default value 1.0.
+ * @param convergenceTol Minibatch iteration will end before numIterations if the relative
+ * difference between the current and previous weight vectors is less than this
+ * value. Convergence is measured with the L2 norm.
+ * Default value 0.001. Must be between 0.0 and 1.0 inclusive.
* @return A tuple containing two elements. The first element is a column matrix containing
* weights for every feature, and the second element is an array containing the
* stochastic loss computed for every iteration.
@@ -154,9 +177,20 @@ object GradientDescent extends Logging {
numIterations: Int,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): (Vector, Array[Double]) = {
+ initialWeights: Vector,
+ convergenceTol: Double): (Vector, Array[Double]) = {
+
+ // convergenceTol is only reliable with non-minibatch settings (miniBatchFraction = 1.0)
+ if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
+ logWarning("Testing against a convergenceTol when using miniBatchFraction " +
+ "< 1.0 can be unstable because of the stochasticity in sampling.")
+ }
val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
+ // Record the previous and current weights to compute the difference between
+ // consecutive solution vectors
+ var previousWeights: Option[Vector] = None
+ var currentWeights: Option[Vector] = None
val numExamples = data.count()
@@ -181,7 +215,9 @@ object GradientDescent extends Logging {
var regVal = updater.compute(
weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
- for (i <- 1 to numIterations) {
+ var converged = false // indicates whether converged based on convergenceTol
+ var i = 1
+ while (!converged && i <= numIterations) {
val bcWeights = data.context.broadcast(weights)
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
@@ -204,12 +240,21 @@ object GradientDescent extends Logging {
*/
stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
val update = updater.compute(
- weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+ weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
+ stepSize, i, regParam)
weights = update._1
regVal = update._2
+
+ previousWeights = currentWeights
+ currentWeights = Some(weights)
+ if (previousWeights.isDefined && currentWeights.isDefined) {
+ converged = isConverged(previousWeights.get,
+ currentWeights.get, convergenceTol)
+ }
} else {
logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
}
+ i += 1
}
logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
@@ -218,4 +263,32 @@ object GradientDescent extends Logging {
(weights, stochasticLossHistory.toArray)
}
+
+ /**
+ * Alias of runMiniBatchSGD with convergenceTol set to the default value 0.001.
+ */
+ def runMiniBatchSGD(
+ data: RDD[(Double, Vector)],
+ gradient: Gradient,
+ updater: Updater,
+ stepSize: Double,
+ numIterations: Int,
+ regParam: Double,
+ miniBatchFraction: Double,
+ initialWeights: Vector): (Vector, Array[Double]) =
+ GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
+ regParam, miniBatchFraction, initialWeights, 0.001)
+
+ private def isConverged(
+ previousWeights: Vector,
+ currentWeights: Vector,
+ convergenceTol: Double): Boolean = {
+ // Convert the weights to Breeze dense vectors so their L2 norms can be computed.
+ val previousBDV = previousWeights.toBreeze.toDenseVector
+ val currentBDV = currentWeights.toBreeze.toDenseVector
+
+ // L2 norm of the difference between the solution vectors of consecutive iterations.
+ val solutionVecDiff: Double = norm(previousBDV - currentBDV)
+
+ solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0)
+ }
+
}
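For reference (not part of the patch): a self-contained sketch of the convergence check implemented in isConverged above, using Breeze directly. The helper name converged and the sample numbers are made up for illustration.

  import breeze.linalg.{norm, DenseVector => BDV}

  // Same rule as isConverged: the L2 norm of the weight difference is compared
  // against a relative tolerance when the new solution's norm is > 1, and an
  // absolute tolerance otherwise.
  def converged(prev: BDV[Double], curr: BDV[Double], tol: Double): Boolean =
    norm(prev - curr) < tol * math.max(norm(curr), 1.0)

  val prev = BDV(1.000, -2.000)
  val curr = BDV(1.001, -2.001)
  // norm(prev - curr) ~= 0.0014 and 0.001 * norm(curr) ~= 0.0022, so this converges.
  println(converged(prev, curr, 0.001)) // true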
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index 235e043c77..c6d04464a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -85,4 +85,10 @@ class StreamingLinearRegressionWithSGD private[mllib] (
this
}
+ /** Set the convergence tolerance. */
+ def setConvergenceTol(tolerance: Double): this.type = {
+ this.algorithm.optimizer.setConvergenceTol(tolerance)
+ this
+ }
+
}
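A usage sketch for the new streaming setter (illustrative only, not part of the patch): the input stream is assumed to exist, and the values mirror the StreamingLinearRegressionSuite change below.

  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
  import org.apache.spark.streaming.dstream.DStream

  def fit(trainingStream: DStream[LabeledPoint]): Unit = {
    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.dense(0.0, 0.0))
      .setStepSize(0.2)
      .setNumIterations(25)
      .setConvergenceTol(0.0001) // each mini-batch fit may now stop early
    model.trainOn(trainingStream)
  }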
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index e8f3d0c4db..2473510e13 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -196,6 +196,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
.setStepSize(10.0)
.setRegParam(0.0)
.setNumIterations(20)
+ .setConvergenceTol(0.0005)
val model = lr.run(testRDD)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a5a59e9fad..13b754a039 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
object GradientDescentSuite {
@@ -82,11 +82,11 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with
// Add an extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Vectors.dense(1.0 +: features.toArray)
+ label -> MLUtils.appendBias(features)
}
val dataRDD = sc.parallelize(data, 2).cache()
- val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray)
+ val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
val (_, loss) = GradientDescent.runMiniBatchSGD(
dataRDD,
@@ -139,6 +139,45 @@ class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with
"The different between newWeights with/without regularization " +
"should be initialWeightsWithIntercept.")
}
+
+ test("iteration should end with convergence tolerance") {
+ val nPoints = 10000
+ val A = 2.0
+ val B = -1.5
+
+ val initialB = -1.0
+ val initialWeights = Array(initialB)
+
+ val gradient = new LogisticGradient()
+ val updater = new SimpleUpdater()
+ val stepSize = 1.0
+ val numIterations = 10
+ val regParam = 0
+ val miniBatchFrac = 1.0
+ val convergenceTolerance = 5.0e-1
+
+ // Add an extra variable consisting of all 1.0's for the intercept.
+ val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
+ val data = testData.map { case LabeledPoint(label, features) =>
+ label -> MLUtils.appendBias(features)
+ }
+
+ val dataRDD = sc.parallelize(data, 2).cache()
+ val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
+
+ val (_, loss) = GradientDescent.runMiniBatchSGD(
+ dataRDD,
+ gradient,
+ updater,
+ stepSize,
+ numIterations,
+ regParam,
+ miniBatchFrac,
+ initialWeightsWithIntercept,
+ convergenceTolerance)
+
+ assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early")
+ }
}
class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index d07b9d5b89..75ae0eb32f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -122,7 +122,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
numGDIterations,
regParam,
miniBatchFrac,
- initialWeightsWithIntercept)
+ initialWeightsWithIntercept,
+ convergenceTol)
assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
"The first losses of LBFGS and GD should be the same.")
@@ -221,7 +222,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
numGDIterations,
regParam,
miniBatchFrac,
- initialWeightsWithIntercept)
+ initialWeightsWithIntercept,
+ convergenceTol)
// for class LBFGS and the optimize method, we only look at the weights
assert(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 08a152ffc7..39537e7bb4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -100,7 +100,7 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005)
val model = ls.run(testRDD, initialWeights)
val weight0 = model.weights(0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index f5e2d31056..a2a4c5f6b8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -53,6 +53,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
.setInitialWeights(Vectors.dense(0.0, 0.0))
.setStepSize(0.2)
.setNumIterations(25)
+ .setConvergenceTol(0.0001)
// generate sequence of simulated data
val numBatches = 10