Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 89
1 file changed, 66 insertions(+), 23 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3967151f76..8fc9199fb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
-import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.{Logging, SparkException}
@@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
  with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasThreshold
+  with HasThreshold with HasStandardization
/**
* :: Experimental ::
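For context, `HasStandardization` comes from Spark ML's shared-param traits. A simplified sketch of what such a mixin provides (abridged; treat the exact doc string as illustrative):

import org.apache.spark.ml.param.{BooleanParam, Params}

private[ml] trait HasStandardization extends Params {
  final val standardization: BooleanParam = new BooleanParam(this,
    "standardization", "whether to standardize the training features before fitting the model")
  final def getStandardization: Boolean = $(standardization)
}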
@@ -98,6 +98,18 @@ class LogisticRegression(override val uid: String)
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
+  /**
+   * Whether to standardize the training features before fitting the model.
+   * The coefficients of the model are always returned on the original scale,
+   * so standardization is transparent to users. Note that with no regularization,
+   * the model converges to the same solution with or without standardization.
+   * Default is true.
+   * @group setParam
+   */
+  def setStandardization(value: Boolean): this.type = set(standardization, value)
+  setDefault(standardization -> true)
+
/** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
setDefault(threshold -> 0.5)
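A minimal usage sketch of the new parameter (hedged: `training` is an assumed DataFrame of label/features rows, and `weights` is the model's coefficient vector in this version of the API):

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
  .setRegParam(0.1)
  .setStandardization(false) // penalize coefficients on the original feature scale
val model = lr.fit(training)
// Coefficients come back on the original scale either way; only the
// effective penalty changes when regularization is enabled.
println(model.weights)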
@@ -149,15 +161,28 @@ class LogisticRegression(override val uid: String)
    val regParamL1 = $(elasticNetParam) * $(regParam)
    val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
-    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
      featuresStd, featuresMean, regParamL2)
    val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
    } else {
-      // Remove the L1 penalization on the intercept
      def regParamL1Fun = (index: Int) => {
-        if (index == numFeatures) 0.0 else regParamL1
+        // Remove the L1 penalization on the intercept
+        if (index == numFeatures) {
+          0.0
+        } else {
+          if ($(standardization)) {
+            regParamL1
+          } else {
+            // If `standardization` is false, we still standardize the data
+            // to improve the rate of convergence; as a result, we have to
+            // perform this reverse standardization by penalizing each component
+            // differently to get effectively the same objective function when
+            // the training dataset is not standardized.
+            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+          }
+        }
      }
      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
    }
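To make the per-coordinate L1 penalty concrete, here is a standalone sketch with illustrative values; `numFeatures`, `featuresStd`, and the parameter settings are assumptions for the example (the intercept occupies the last slot, as above):

val regParam = 0.1                      // overall lambda
val elasticNetParam = 0.3               // 0.0 => pure L2 (ridge), 1.0 => pure L1 (lasso)
val regParamL1 = elasticNetParam * regParam          // 0.03
val regParamL2 = (1.0 - elasticNetParam) * regParam  // 0.07
val numFeatures = 3
val featuresStd = Array(2.0, 0.0, 0.5)  // per-feature standard deviations
val standardization = false

def l1PenaltyFor(index: Int): Double =
  if (index == numFeatures) 0.0         // last slot is the intercept: never penalized
  else if (standardization) regParamL1  // uniform penalty in the standardized space
  else if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) // undo the scaling
  else 0.0                              // constant feature: nothing to penalize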
@@ -523,11 +548,13 @@ private class LogisticCostFun(
    data: RDD[(Double, Vector)],
    numClasses: Int,
    fitIntercept: Boolean,
+    standardization: Boolean,
    featuresStd: Array[Double],
    featuresMean: Array[Double],
    regParamL2: Double) extends DiffFunction[BDV[Double]] {
  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+    val numFeatures = featuresStd.length
    val w = Vectors.fromBreeze(weights)
    val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
@@ -539,27 +566,43 @@ private class LogisticCostFun(
      case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
    })
-    // regVal is the sum of weight squares for L2 regularization
-    val norm = if (regParamL2 == 0.0) {
-      0.0
-    } else if (fitIntercept) {
-      brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
-    } else {
-      brzNorm(weights, 2.0)
-    }
-    val regVal = 0.5 * regParamL2 * norm * norm
+    val totalGradientArray = logisticAggregator.gradient.toArray
-    val loss = logisticAggregator.loss + regVal
-    val gradient = logisticAggregator.gradient
-
-    if (fitIntercept) {
-      val wArray = w.toArray.clone()
-      wArray(wArray.length - 1) = 0.0
-      axpy(regParamL2, Vectors.dense(wArray), gradient)
+    // regVal is the sum of weight squares excluding intercept for L2 regularization.
+    val regVal = if (regParamL2 == 0.0) {
+      0.0
    } else {
-      axpy(regParamL2, w, gradient)
+      var sum = 0.0
+      w.foreachActive { (index, value) =>
+        // If `fitIntercept` is true, the last weight is the intercept and doesn't
+        // contribute to the regularization.
+        if (index != numFeatures) {
+          // The following code computes both the regularization loss and its
+          // gradient, adding the gradient back into totalGradientArray.
+          sum += {
+            if (standardization) {
+              totalGradientArray(index) += regParamL2 * value
+              value * value
+            } else {
+              if (featuresStd(index) != 0.0) {
+                // If `standardization` is false, we still standardize the data
+                // to improve the rate of convergence; as a result, we have to
+                // perform this reverse standardization by penalizing each component
+                // differently to get effectively the same objective function when
+                // the training dataset is not standardized.
+                val temp = value / (featuresStd(index) * featuresStd(index))
+                totalGradientArray(index) += regParamL2 * temp
+                value * temp
+              } else {
+                0.0
+              }
+            }
+          }
+        }
+      }
+      0.5 * regParamL2 * sum
    }
-    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
  }
}
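Why `value / (featuresStd(index) * featuresStd(index))`: optimization runs in the standardized space, where the effective original-scale coefficient is `w(i) / featuresStd(i)`, so an original-scale L2 penalty sums `(w(i) / std(i))^2 = w(i) * temp`. A small self-contained check with made-up numbers:

// Illustrative check (made-up values): the reverse-standardized term above
// equals the plain L2 penalty of the equivalent original-scale coefficient.
val lambda = 0.5
val std = 2.0
val w = 3.0                       // coefficient in the standardized space
val wOriginal = w / std           // equivalent coefficient on the original scale
val temp = w / (std * std)
val penaltyOriginalScale = 0.5 * lambda * wOriginal * wOriginal
val penaltyAsComputed = 0.5 * lambda * w * temp
assert(math.abs(penaltyOriginalScale - penaltyAsComputed) < 1e-12)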