Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala  89
1 file changed, 66 insertions, 23 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3967151f76..8fc9199fb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification

 import scala.collection.mutable

-import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}

 import org.apache.spark.{Logging, SparkException}
@@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasThreshold
+  with HasThreshold with HasStandardization

 /**
  * :: Experimental ::
@@ -98,6 +98,18 @@ class LogisticRegression(override val uid: String)
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)

+  /**
+   * Whether to standardize the training features before fitting the model.
+   * The coefficients of the models are always returned on the original scale,
+   * so it is transparent to users. Note that with no regularization, the
+   * models should converge to the same solution whether the features are
+   * standardized or not.
+   * Default is true.
+   * @group setParam
+   */
+  def setStandardization(value: Boolean): this.type = set(standardization, value)
+  setDefault(standardization -> true)
+
   /** @group setParam */
   def setThreshold(value: Double): this.type = set(threshold, value)
   setDefault(threshold -> 0.5)
@@ -149,15 +161,28 @@ class LogisticRegression(override val uid: String)
     val regParamL1 = $(elasticNetParam) * $(regParam)
     val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)

-    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+    val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
       featuresStd, featuresMean, regParamL2)

     val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
     } else {
-      // Remove the L1 penalization on the intercept
       def regParamL1Fun = (index: Int) => {
-        if (index == numFeatures) 0.0 else regParamL1
+        // Remove the L1 penalization on the intercept
+        if (index == numFeatures) {
+          0.0
+        } else {
+          if ($(standardization)) {
+            regParamL1
+          } else {
+            // If `standardization` is false, we still standardize the data
+            // to improve the rate of convergence; as a result, we have to
+            // perform this reverse standardization by penalizing each component
+            // differently to get effectively the same objective function when
+            // the training dataset is not standardized.
+            if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+          }
+        }
       }
       new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
     }
@@ -523,11 +548,13 @@ private class LogisticCostFun(
     data: RDD[(Double, Vector)],
     numClasses: Int,
     fitIntercept: Boolean,
+    standardization: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     regParamL2: Double) extends DiffFunction[BDV[Double]] {

   override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+    val numFeatures = featuresStd.length
     val w = Vectors.fromBreeze(weights)

     val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
@@ -539,27 +566,43 @@ private class LogisticCostFun(
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
       })

-    // regVal is the sum of weight squares for L2 regularization
-    val norm = if (regParamL2 == 0.0) {
-      0.0
-    } else if (fitIntercept) {
-      brzNorm(Vectors.dense(weights.toArray.slice(0, weights.size -1)).toBreeze, 2.0)
-    } else {
-      brzNorm(weights, 2.0)
-    }
-    val regVal = 0.5 * regParamL2 * norm * norm
+    val totalGradientArray = logisticAggregator.gradient.toArray

-    val loss = logisticAggregator.loss + regVal
-    val gradient = logisticAggregator.gradient
-
-    if (fitIntercept) {
-      val wArray = w.toArray.clone()
-      wArray(wArray.length - 1) = 0.0
-      axpy(regParamL2, Vectors.dense(wArray), gradient)
+    // regVal is the sum of weight squares excluding the intercept for L2 regularization.
+    val regVal = if (regParamL2 == 0.0) {
+      0.0
     } else {
-      axpy(regParamL2, w, gradient)
+      var sum = 0.0
+      w.foreachActive { (index, value) =>
+        // If `fitIntercept` is true, the last term, which is the intercept,
+        // doesn't contribute to the regularization.
+        if (index != numFeatures) {
+          // The following code computes both the regularization loss and its
+          // gradient, adding the gradient contribution back to totalGradientArray.
+          sum += {
+            if (standardization) {
+              totalGradientArray(index) += regParamL2 * value
+              value * value
+            } else {
+              if (featuresStd(index) != 0.0) {
+                // If `standardization` is false, we still standardize the data
+                // to improve the rate of convergence; as a result, we have to
+                // perform this reverse standardization by penalizing each component
+                // differently to get effectively the same objective function when
+                // the training dataset is not standardized.
+                val temp = value / (featuresStd(index) * featuresStd(index))
+                totalGradientArray(index) += regParamL2 * temp
+                value * temp
+              } else {
+                0.0
+              }
+            }
+          }
+        }
+      }
+      0.5 * regParamL2 * sum
     }
-    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    (logisticAggregator.loss + regVal, new BDV(totalGradientArray))
   }
 }
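What makes regParamL1Fun worth the extra branching: the optimizer always runs on standardized features x_j / std_j, so the weight it sees for feature j is w'_j = w_j * std_j, and a per-component penalty of regParamL1 / std_j on w'_j yields the same objective as a uniform penalty of regParamL1 on the original-scale w_j. A minimal standalone sketch of that equivalence (the object name is hypothetical; this code is not part of the patch):

    object ReverseStandardizationL1Check {
      def main(args: Array[String]): Unit = {
        val regParamL1 = 0.1
        val featuresStd = Array(2.0, 0.5, 4.0)
        // Coefficients on the original scale, and their standardized-space images.
        val wOriginal = Array(1.5, -3.0, 0.25)
        val wStandardized = wOriginal.zip(featuresStd).map { case (w, s) => w * s }

        // Uniform L1 penalty on the original-scale coefficients.
        val penaltyOriginal = wOriginal.map(w => regParamL1 * math.abs(w)).sum

        // Per-component penalty regParamL1 / std on the standardized coefficients,
        // mirroring regParamL1Fun (a zero-variance feature gets no penalty).
        val penaltyReversed = wStandardized.zip(featuresStd).map { case (w, s) =>
          if (s != 0.0) (regParamL1 / s) * math.abs(w) else 0.0
        }.sum

        // Both print 0.475000: the two objectives agree.
        println(f"$penaltyOriginal%.6f == $penaltyReversed%.6f")
      }
    }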
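The L2 branch in LogisticCostFun applies the same reverse standardization to both the loss and the gradient: with value = w'_j, the term value * (value / std_j^2) recovers w_j^2, and regParamL2 * value / std_j^2 is the derivative of 0.5 * regParamL2 * (w'_j / std_j)^2 with respect to w'_j, which is what gets accumulated into totalGradientArray. A companion sketch for the loss side (again hypothetical and self-contained):

    object ReverseStandardizationL2Check {
      def main(args: Array[String]): Unit = {
        val regParamL2 = 0.3
        val featuresStd = Array(2.0, 0.5, 4.0)
        val wOriginal = Array(1.5, -3.0, 0.25)
        val wStandardized = wOriginal.zip(featuresStd).map { case (w, s) => w * s }

        // Plain L2 penalty on the original-scale coefficients.
        val regValOriginal = 0.5 * regParamL2 * wOriginal.map(w => w * w).sum

        // The patch's formulation: value * temp with temp = value / std^2
        // (a zero-variance feature contributes nothing).
        val regValReversed = 0.5 * regParamL2 * wStandardized.zip(featuresStd).map {
          case (value, s) => if (s != 0.0) value * (value / (s * s)) else 0.0
        }.sum

        // Both print 1.696875.
        println(f"$regValOriginal%.6f == $regValReversed%.6f")
      }
    }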
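From the caller's perspective, the new param behaves like any other setter on the estimator. A hedged usage sketch (a DataFrame `training` with the usual "label" and "features" columns is assumed and not shown in this diff); per the new doc comment, the setting only changes the fitted model when regularization is active:

    val lr = new LogisticRegression()
      .setRegParam(0.1)
      .setElasticNetParam(1.0)     // pure L1
      .setStandardization(false)   // penalize coefficients on the original feature scale
    val model = lr.fit(training)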