author    DB Tsai <dbt@netflix.com>                    2015-07-07 15:46:44 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-07-07 15:46:44 -0700
commit    3bf20c27ff3cb3a32bfc3a44e08a57865957c117 (patch)
tree      a40ab74ff4d59eee13206380abb549553639667d /mllib
parent    35d781e71b68eb6da7f49fdae40fa6c4f8e27060 (diff)
[SPARK-8845] [ML] ML use of Breeze optimization: use adjustedValue instead of value
In LinearRegression and LogisticRegression, we use Breeze's optimizers (LBFGS and OWLQN). We check the State.value to see the current objective. However, Breeze's documentation makes it sound like value and adjustedValue differ for some optimizers, possibly including OWLQN:
https://github.com/scalanlp/breeze/blob/26faf622862e8d7a42a401aef601347aac655f2b/math/src/main/scala/breeze/optimize/FirstOrderMinimizer.scala#L36

If that is the case, then we should use adjustedValue instead of value. This is relevant to SPARK-8538 and SPARK-8539, where we will provide the objective trace to the user.

Author: DB Tsai <dbt@netflix.com>

Closes #7245 from dbtsai/SPARK-8845 and squashes the following commits:

fa4c91e [DB Tsai] address feedback
e6caac1 [DB Tsai] java style multiline comment
b10c574 [DB Tsai] address feedback
c9ff81e [DB Tsai] first commit
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala  83
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala        57
2 files changed, 80 insertions, 60 deletions
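The core of the patch in both files is to iterate the Breeze optimizer's states and record state.adjustedValue (the objective including any adjustment the optimizer applies, e.g. the L1 penalty in OWLQN) instead of state.value. Below is a minimal, self-contained sketch of that pattern; the toy quadratic objective and parameter values are illustrative only, not Spark's actual cost function.

import scala.collection.mutable

import breeze.linalg.DenseVector
import breeze.optimize.{DiffFunction, LBFGS}

object AdjustedValueSketch {
  def main(args: Array[String]): Unit = {
    // Toy objective: f(x) = ||x - 3||^2 with gradient 2 * (x - 3).
    val costFun = new DiffFunction[DenseVector[Double]] {
      override def calculate(x: DenseVector[Double]): (Double, DenseVector[Double]) = {
        val diff = x - 3.0
        (diff dot diff, diff * 2.0)
      }
    }

    val optimizer = new LBFGS[DenseVector[Double]](maxIter = 100, m = 10)
    val states = optimizer.iterations(costFun, DenseVector.zeros[Double](5))

    // Record adjustedValue for every state, mirroring the loop introduced by this commit.
    // For plain LBFGS value and adjustedValue coincide; for OWLQN they can differ.
    val lossHistory = mutable.ArrayBuilder.make[Double]
    var state: optimizer.State = null
    while (states.hasNext) {
      state = states.next()
      lossHistory += state.adjustedValue
    }

    println(lossHistory.result().mkString(", "))
  }
}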
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 2e6eedd45a..3967151f76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -116,7 +116,7 @@ class LogisticRegression(override val uid: String)
case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
(label: Double, features: Vector)) =>
(summarizer.add(features), labelSummarizer.add(label))
- },
+ },
combOp = (c1, c2) => (c1, c2) match {
case ((summarizer1: MultivariateOnlineSummarizer,
classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
@@ -166,18 +166,18 @@ class LogisticRegression(override val uid: String)
Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
if ($(fitIntercept)) {
- /**
- * For binary logistic regression, when we initialize the weights as zeros,
- * it will converge faster if we initialize the intercept such that
- * it follows the distribution of the labels.
- *
- * {{{
- * P(0) = 1 / (1 + \exp(b)), and
- * P(1) = \exp(b) / (1 + \exp(b))
- * }}}, hence
- * {{{
- * b = \log{P(1) / P(0)} = \log{count_1 / count_0}
- * }}}
+ /*
+ For binary logistic regression, when we initialize the weights as zeros,
+ it will converge faster if we initialize the intercept such that
+ it follows the distribution of the labels.
+
+ {{{
+ P(0) = 1 / (1 + \exp(b)), and
+ P(1) = \exp(b) / (1 + \exp(b))
+ }}}, hence
+ {{{
+ b = \log{P(1) / P(0)} = \log{count_1 / count_0}
+ }}}
*/
initialWeightsWithIntercept.toArray(numFeatures)
= math.log(histogram(1).toDouble / histogram(0).toDouble)
@@ -186,39 +186,48 @@ class LogisticRegression(override val uid: String)
val states = optimizer.iterations(new CachedDiffFunction(costFun),
initialWeightsWithIntercept.toBreeze.toDenseVector)
- var state = states.next()
- val lossHistory = mutable.ArrayBuilder.make[Double]
+ val (weights, intercept, objectiveHistory) = {
+ /*
+ Note that in Logistic Regression, the objective history (loss + regularization)
+ is the log-likelihood, which is invariant under feature standardization. As a result,
+ the objective history from the optimizer is the same as the one in the original space.
+ */
+ val arrayBuilder = mutable.ArrayBuilder.make[Double]
+ var state: optimizer.State = null
+ while (states.hasNext) {
+ state = states.next()
+ arrayBuilder += state.adjustedValue
+ }
- while (states.hasNext) {
- lossHistory += state.value
- state = states.next()
- }
- lossHistory += state.value
+ if (state == null) {
+ val msg = s"${optimizer.getClass.getName} failed."
+ logError(msg)
+ throw new SparkException(msg)
+ }
- // The weights are trained in the scaled space; we're converting them back to
- // the original space.
- val weightsWithIntercept = {
+ /*
+ The weights are trained in the scaled space; we're converting them back to
+ the original space.
+ Note that the intercept in scaled space and original space is the same;
+ as a result, no scaling is needed.
+ */
val rawWeights = state.x.toArray.clone()
var i = 0
- // Note that the intercept in scaled space and original space is the same;
- // as a result, no scaling is needed.
while (i < numFeatures) {
rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
i += 1
}
- Vectors.dense(rawWeights)
+
+ if ($(fitIntercept)) {
+ (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result())
+ } else {
+ (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result())
+ }
}
if (handlePersistence) instances.unpersist()
- val (weights, intercept) = if ($(fitIntercept)) {
- (Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)),
- weightsWithIntercept(weightsWithIntercept.size - 1))
- } else {
- (weightsWithIntercept, 0.0)
- }
-
- new LogisticRegressionModel(uid, weights.compressed, intercept)
+ copyValues(new LogisticRegressionModel(uid, weights, intercept))
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
@@ -423,16 +432,12 @@ private class LogisticAggregator(
require(dim == data.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${data.size}.")
- val dataSize = data.size
-
val localWeightsArray = weightsArray
val localGradientSumArray = gradientSumArray
numClasses match {
case 2 =>
- /**
- * For Binary Logistic Regression.
- */
+ // For Binary Logistic Regression.
val margin = - {
var sum = 0.0
data.foreachActive { (index, value) =>
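For reference, the intercept initialization described in the rewritten comment above, b = log(count_1 / count_0), can be checked numerically in isolation. A short Scala sketch; the label counts are made up for illustration and are not from the patch:

// Hypothetical label histogram: 30 examples with label 0, 70 with label 1.
val count0 = 30.0
val count1 = 70.0

// Intercept that reproduces the label distribution when all weights are zero:
// P(1) = exp(b) / (1 + exp(b)) should equal count1 / (count0 + count1).
val b = math.log(count1 / count0)
val p1 = math.exp(b) / (1.0 + math.exp(b))

assert(math.abs(p1 - count1 / (count0 + count1)) < 1e-12)  // p1 == 0.7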
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 1b1d7299fb..f672c96576 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
@@ -132,7 +132,6 @@ class LinearRegression(override val uid: String)
val numFeatures = summarizer.mean.size
val yMean = statCounter.mean
val yStd = math.sqrt(statCounter.variance)
- // look at glmnet5.m L761 maaaybe that has info
// If the yStd is zero, then the intercept is yMean with zero weights;
// as a result, training is not needed.
@@ -162,21 +161,34 @@ class LinearRegression(override val uid: String)
}
val initialWeights = Vectors.zeros(numFeatures)
- val states =
- optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
-
- var state = states.next()
- val lossHistory = mutable.ArrayBuilder.make[Double]
+ val states = optimizer.iterations(new CachedDiffFunction(costFun),
+ initialWeights.toBreeze.toDenseVector)
+
+ val (weights, objectiveHistory) = {
+ /*
+ Note that in Linear Regression, the objective history (loss + regularization) returned
+ from the optimizer is computed in the scaled space given by the following formula.
+ {{{
+ L = 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2 + regTerms
+ }}}
+ */
+ val arrayBuilder = mutable.ArrayBuilder.make[Double]
+ var state: optimizer.State = null
+ while (states.hasNext) {
+ state = states.next()
+ arrayBuilder += state.adjustedValue
+ }
- while (states.hasNext) {
- lossHistory += state.value
- state = states.next()
- }
- lossHistory += state.value
+ if (state == null) {
+ val msg = s"${optimizer.getClass.getName} failed."
+ logError(msg)
+ throw new SparkException(msg)
+ }
- // The weights are trained in the scaled space; we're converting them back to
- // the original space.
- val weights = {
+ /*
+ The weights are trained in the scaled space; we're converting them back to
+ the original space.
+ */
val rawWeights = state.x.toArray.clone()
var i = 0
val len = rawWeights.length
@@ -184,17 +196,20 @@ class LinearRegression(override val uid: String)
rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
i += 1
}
- Vectors.dense(rawWeights)
+
+ (Vectors.dense(rawWeights).compressed, arrayBuilder.result())
}
- // The intercept in R's GLMNET is computed using closed form after the coefficients are
- // converged. See the following discussion for detail.
- // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ /*
+ The intercept in R's GLMNET is computed using closed form after the coefficients are
+ converged. See the following discussion for detail.
+ http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ */
val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
+
if (handlePersistence) instances.unpersist()
- // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
- copyValues(new LinearRegressionModel(uid, weights.compressed, intercept))
+ copyValues(new LinearRegressionModel(uid, weights, intercept))
}
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
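For reference, the GLMNET-style intercept mentioned in the comment above amounts to intercept = yMean - dot(weights, featuresMean), computed in closed form after the coefficients converge. A tiny standalone Scala check; the weights and means below are made-up numbers, not values from the patch:

// Hypothetical converged weights and feature/label means (values for illustration only).
val weights = Array(0.5, -1.2)
val featuresMean = Array(2.0, 0.3)
val yMean = 4.0

// Closed-form intercept, mirroring `yMean - dot(weights, Vectors.dense(featuresMean))`.
val intercept = yMean - weights.zip(featuresMean).map { case (w, m) => w * m }.sum
// 4.0 - (0.5 * 2.0 + (-1.2) * 0.3) == 3.36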