aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDB Tsai <dbtsai@alpinenow.com>2014-08-14 11:56:13 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-14 11:56:13 -0700
commit96221067572e5955af1a7710b0cca33a73db4bd5 (patch)
tree9e844dfef5c98dafae19d3c1fd55a8eed3c7e570 /mllib
parenteaeb0f76fa0f103c7db0f3975cb8562715410973 (diff)
downloadspark-96221067572e5955af1a7710b0cca33a73db4bd5.tar.gz
spark-96221067572e5955af1a7710b0cca33a73db4bd5.tar.bz2
spark-96221067572e5955af1a7710b0cca33a73db4bd5.zip
[SPARK-2979][MLlib] Improve the convergence rate by minimizing the condition number
In theory, the scale of your inputs are irrelevant to logistic regression. You can "theoretically" multiply X1 by 1E6 and the estimate for β1 will adjust accordingly. It will be 1E-6 times smaller than the original β1, due to the invariance property of MLEs. However, during the optimization process, the convergence (rate) depends on the condition number of the training dataset. Scaling the variables often reduces this condition number, thus improving the convergence rate. Without reducing the condition number, some training datasets mixing the columns with different scales may not be able to converge. GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return the weights in the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf Here, if useFeatureScaling is enabled, we will standardize the training features by dividing the variance of each column (without subtracting the mean to densify the sparse vector), and train the model in the scaled space. Then we transform the coefficients from the scaled space to the original scale as GLMNET and LIBSVM do. Currently, it's only enabled in LogisticRegressionWithLBFGS. Author: DB Tsai <dbtsai@alpinenow.com> Closes #1897 from dbtsai/dbtsai-feature-scaling and squashes the following commits: f19fc02 [DB Tsai] Added more comments 1d85289 [DB Tsai] Improve the convergence rate by minimize the condition number in LOR with LBFGS
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala69
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala57
3 files changed, 126 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 31d474a20f..6790c86f65 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -62,7 +62,7 @@ class LogisticRegressionModel (
override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
- val score = 1.0/ (1.0 + math.exp(-margin))
+ val score = 1.0 / (1.0 + math.exp(-margin))
threshold match {
case Some(t) => if (score < t) 0.0 else 1.0
case None => score
@@ -204,6 +204,8 @@ class LogisticRegressionWithLBFGS private (
*/
def this() = this(1E-4, 100, 0.0)
+ this.setFeatureScaling(true)
+
private val gradient = new LogisticGradient()
private val updater = new SimpleUpdater()
// Have to return new LBFGS object every time since users can reset the parameters anytime.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 54854252d7..20c1fdd226 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.regression
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
@@ -95,6 +96,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
protected var validateData: Boolean = true
/**
+ * Whether to perform feature scaling before model training to reduce the condition numbers
+ * which can significantly help the optimizer converging faster. The scaling correction will be
+ * translated back to resulting model weights, so it's transparent to users.
+ * Note: This technique is used in both libsvm and glmnet packages. Default false.
+ */
+ private var useFeatureScaling = false
+
+ /**
+ * Set if the algorithm should use feature scaling to improve the convergence during optimization.
+ */
+ private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = {
+ this.useFeatureScaling = useFeatureScaling
+ this
+ }
+
+ /**
* Create a model given the weights and intercept
*/
protected def createModel(weights: Vector, intercept: Double): M
@@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
throw new SparkException("Input validation failed.")
}
+ /**
+ * Scaling columns to unit variance as a heuristic to reduce the condition number:
+ *
+ * During the optimization process, the convergence (rate) depends on the condition number of
+ * the training dataset. Scaling the variables often reduces this condition number
+ * heuristically, thus improving the convergence rate. Without reducing the condition number,
+ * some training datasets mixing the columns with different scales may not be able to converge.
+ *
+ * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return
+ * the weights in the original scale.
+ * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing
+ * the variance of each column (without subtracting the mean), and train the model in the
+ * scaled space. Then we transform the coefficients from the scaled space to the original scale
+ * as GLMNET and LIBSVM do.
+ *
+ * Currently, it's only enabled in LogisticRegressionWithLBFGS
+ */
+ val scaler = if (useFeatureScaling) {
+ (new StandardScaler).fit(input.map(x => x.features))
+ } else {
+ null
+ }
+
// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
+ if(useFeatureScaling) {
+ input.map(labeledPoint =>
+ (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features))))
+ } else {
+ input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
+ }
} else {
- input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+ if (useFeatureScaling) {
+ input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features)))
+ } else {
+ input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+ }
}
val initialWeightsWithIntercept = if (addIntercept) {
@@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
- val weights =
+ var weights =
if (addIntercept) {
Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
} else {
weightsWithIntercept
}
+ /**
+ * The weights and intercept are trained in the scaled space; we're converting them back to
+ * the original scale.
+ *
+ * Math shows that if we only perform standardization without subtracting means, the intercept
+ * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i
+ * is the coefficient in the original space, and v_i is the variance of the column i.
+ */
+ if (useFeatureScaling) {
+ weights = scaler.transform(weights)
+ }
+
createModel(weights, intercept)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 2289c6cdc1..bc05b20468 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -185,6 +185,63 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("numerical stability of scaling features using logistic regression with LBFGS") {
+ /**
+ * If we rescale the features, the condition number will be changed so the convergence rate
+ * and the solution will not equal to the original solution multiple by the scaling factor
+ * which it should be.
+ *
+ * However, since in the LogisticRegressionWithLBFGS, we standardize the training dataset first,
+ * no matter how we multiple a scaling factor into the dataset, the convergence rate should be
+ * the same, and the solution should equal to the original solution multiple by the scaling
+ * factor.
+ */
+
+ val nPoints = 10000
+ val A = 2.0
+ val B = -1.5
+
+ val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+ val initialWeights = Vectors.dense(0.0)
+
+ val testRDD1 = sc.parallelize(testData, 2)
+
+ val testRDD2 = sc.parallelize(
+ testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2)
+
+ val testRDD3 = sc.parallelize(
+ testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2)
+
+ testRDD1.cache()
+ testRDD2.cache()
+ testRDD3.cache()
+
+ val lrA = new LogisticRegressionWithLBFGS().setIntercept(true)
+ val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false)
+
+ val modelA1 = lrA.run(testRDD1, initialWeights)
+ val modelA2 = lrA.run(testRDD2, initialWeights)
+ val modelA3 = lrA.run(testRDD3, initialWeights)
+
+ val modelB1 = lrB.run(testRDD1, initialWeights)
+ val modelB2 = lrB.run(testRDD2, initialWeights)
+ val modelB3 = lrB.run(testRDD3, initialWeights)
+
+ // For model trained with feature standardization, the weights should
+ // be the same in the scaled space. Note that the weights here are already
+ // in the original space, we transform back to scaled space to compare.
+ assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01)
+ assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01)
+
+ // Training data with different scales without feature standardization
+ // will not yield the same result in the scaled space due to poor
+ // convergence rate.
+ assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1)
+ assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
+ }
+
}
class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {