-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala       | 200
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala  |  43
2 files changed, 148 insertions(+), 95 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 486043e8d9..dad8dfc84e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -276,113 +276,123 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
- if (numInvalid != 0) {
- val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
- s"Found $numInvalid invalid labels."
- logError(msg)
- throw new SparkException(msg)
- }
-
- if (numClasses > 2) {
- val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
- s"binary classification. Found $numClasses in the input dataset."
- logError(msg)
- throw new SparkException(msg)
- }
+ val (coefficients, intercept, objectiveHistory) = {
+ if (numInvalid != 0) {
+ val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
+ s"Found $numInvalid invalid labels."
+ logError(msg)
+ throw new SparkException(msg)
+ }
- val featuresMean = summarizer.mean.toArray
- val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+ if (numClasses > 2) {
+ val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
+ s"binary classification. Found $numClasses in the input dataset."
+ logError(msg)
+ throw new SparkException(msg)
+ } else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
+ logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " +
+ s"zeros and the intercept will be positive infinity; as a result, " +
+ s"training is not needed.")
+ (Vectors.sparse(numFeatures, Seq()), Double.PositiveInfinity, Array.empty[Double])
+ } else if ($(fitIntercept) && numClasses == 1) {
+ logWarning(s"All labels are zero and fitIntercept=true, so the coefficients will be " +
+ s"zeros and the intercept will be negative infinity; as a result, " +
+ s"training is not needed.")
+ (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
+ } else {
+ val featuresMean = summarizer.mean.toArray
+ val featuresStd = summarizer.variance.toArray.map(math.sqrt)
- val regParamL1 = $(elasticNetParam) * $(regParam)
- val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
+ val regParamL1 = $(elasticNetParam) * $(regParam)
+ val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
- val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization),
- featuresStd, featuresMean, regParamL2)
+ val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
+ $(standardization), featuresStd, featuresMean, regParamL2)
- val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
- new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
- } else {
- def regParamL1Fun = (index: Int) => {
- // Remove the L1 penalization on the intercept
- if (index == numFeatures) {
- 0.0
+ val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
+ new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
- if ($(standardization)) {
- regParamL1
- } else {
- // If `standardization` is false, we still standardize the data
- // to improve the rate of convergence; as a result, we have to
- // perform this reverse standardization by penalizing each component
- // differently to get effectively the same objective function when
- // the training dataset is not standardized.
- if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+ def regParamL1Fun = (index: Int) => {
+ // Remove the L1 penalization on the intercept
+ if (index == numFeatures) {
+ 0.0
+ } else {
+ if ($(standardization)) {
+ regParamL1
+ } else {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+ }
+ }
}
+ new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
}
- }
- new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
- }
-
- val initialCoefficientsWithIntercept =
- Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
-
- if ($(fitIntercept)) {
- /*
- For binary logistic regression, when we initialize the coefficients as zeros,
- it will converge faster if we initialize the intercept such that
- it follows the distribution of the labels.
-
- {{{
- P(0) = 1 / (1 + \exp(b)), and
- P(1) = \exp(b) / (1 + \exp(b))
- }}}, hence
- {{{
- b = \log{P(1) / P(0)} = \log{count_1 / count_0}
- }}}
- */
- initialCoefficientsWithIntercept.toArray(numFeatures)
- = math.log(histogram(1) / histogram(0))
- }
- val states = optimizer.iterations(new CachedDiffFunction(costFun),
- initialCoefficientsWithIntercept.toBreeze.toDenseVector)
+ val initialCoefficientsWithIntercept =
+ Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
+
+ if ($(fitIntercept)) {
+ /*
+ For binary logistic regression, when we initialize the coefficients as zeros,
+ it will converge faster if we initialize the intercept such that
+ it follows the distribution of the labels.
+
+ {{{
+ P(0) = 1 / (1 + \exp(b)), and
+ P(1) = \exp(b) / (1 + \exp(b))
+ }}}, hence
+ {{{
+ b = \log{P(1) / P(0)} = \log{count_1 / count_0}
+ }}}
+ */
+ initialCoefficientsWithIntercept.toArray(numFeatures) = math.log(
+ histogram(1) / histogram(0))
+ }
- val (coefficients, intercept, objectiveHistory) = {
- /*
- Note that in Logistic Regression, the objective history (loss + regularization)
- is log-likelihood which is invariance under feature standardization. As a result,
- the objective history from optimizer is the same as the one in the original space.
- */
- val arrayBuilder = mutable.ArrayBuilder.make[Double]
- var state: optimizer.State = null
- while (states.hasNext) {
- state = states.next()
- arrayBuilder += state.adjustedValue
- }
+ val states = optimizer.iterations(new CachedDiffFunction(costFun),
+ initialCoefficientsWithIntercept.toBreeze.toDenseVector)
+
+ /*
+ Note that in Logistic Regression, the objective history (loss + regularization)
+ is log-likelihood which is invariance under feature standardization. As a result,
+ the objective history from optimizer is the same as the one in the original space.
+ */
+ val arrayBuilder = mutable.ArrayBuilder.make[Double]
+ var state: optimizer.State = null
+ while (states.hasNext) {
+ state = states.next()
+ arrayBuilder += state.adjustedValue
+ }
- if (state == null) {
- val msg = s"${optimizer.getClass.getName} failed."
- logError(msg)
- throw new SparkException(msg)
- }
+ if (state == null) {
+ val msg = s"${optimizer.getClass.getName} failed."
+ logError(msg)
+ throw new SparkException(msg)
+ }
- /*
- The coefficients are trained in the scaled space; we're converting them back to
- the original space.
- Note that the intercept in scaled space and original space is the same;
- as a result, no scaling is needed.
- */
- val rawCoefficients = state.x.toArray.clone()
- var i = 0
- while (i < numFeatures) {
- rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
- i += 1
- }
+ /*
+ The coefficients are trained in the scaled space; we're converting them back to
+ the original space.
+ Note that the intercept in scaled space and original space is the same;
+ as a result, no scaling is needed.
+ */
+ val rawCoefficients = state.x.toArray.clone()
+ var i = 0
+ while (i < numFeatures) {
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+ i += 1
+ }
- if ($(fitIntercept)) {
- (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
- arrayBuilder.result())
- } else {
- (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
+ if ($(fitIntercept)) {
+ (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
+ arrayBuilder.result())
+ } else {
+ (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
+ }
}
}
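
The regParamL1Fun handed to OWLQN in this hunk is the subtle part: when standardization is
disabled, the L1 strength on each coefficient is divided by that feature's standard deviation,
so that optimizing in the standardized space still minimizes the caller's unstandardized
objective, and the intercept slot is never penalized. A minimal standalone sketch of that
per-index penalty (plain Scala, not the Spark API; the toy values are made up for illustration):

    object L1PenaltySketch {
      // Mirrors regParamL1Fun: no L1 on the intercept slot, full strength when the
      // data is standardized, otherwise strength scaled by 1 / featuresStd(index).
      def penalty(regParamL1: Double, standardization: Boolean, featuresStd: Array[Double],
          numFeatures: Int)(index: Int): Double = {
        if (index == numFeatures) 0.0
        else if (standardization) regParamL1
        else if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index)
        else 0.0
      }

      def main(args: Array[String]): Unit = {
        val featuresStd = Array(2.0, 0.5, 0.0)
        val p = penalty(0.1, standardization = false, featuresStd, featuresStd.length) _
        // Feature penalties: 0.05, 0.2, 0.0 (constant column); intercept slot: 0.0
        println((0 to featuresStd.length).map(p))
      }
    }
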
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ff0d0ff771..972c0868a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -883,6 +884,48 @@ class LogisticRegressionSuite
assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
}
+ test("logistic regression with all labels the same") {
+ val sameLabels = dataset
+ .withColumn("zeroLabel", lit(0.0))
+ .withColumn("oneLabel", lit(1.0))
+
+ // fitIntercept=true
+ val lrIntercept = new LogisticRegression()
+ .setFitIntercept(true)
+ .setMaxIter(3)
+
+ val allZeroInterceptModel = lrIntercept
+ .setLabelCol("zeroLabel")
+ .fit(sameLabels)
+ assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+ assert(allZeroInterceptModel.intercept === Double.NegativeInfinity)
+ assert(allZeroInterceptModel.summary.totalIterations === 0)
+
+ val allOneInterceptModel = lrIntercept
+ .setLabelCol("oneLabel")
+ .fit(sameLabels)
+ assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
+ assert(allOneInterceptModel.intercept === Double.PositiveInfinity)
+ assert(allOneInterceptModel.summary.totalIterations === 0)
+
+ // fitIntercept=false
+ val lrNoIntercept = new LogisticRegression()
+ .setFitIntercept(false)
+ .setMaxIter(3)
+
+ val allZeroNoInterceptModel = lrNoIntercept
+ .setLabelCol("zeroLabel")
+ .fit(sameLabels)
+ assert(allZeroNoInterceptModel.intercept === 0.0)
+ assert(allZeroNoInterceptModel.summary.totalIterations > 0)
+
+ val allOneNoInterceptModel = lrNoIntercept
+ .setLabelCol("oneLabel")
+ .fit(sameLabels)
+ assert(allOneNoInterceptModel.intercept === 0.0)
+ assert(allOneNoInterceptModel.summary.totalIterations > 0)
+ }
+
test("read/write") {
def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
assert(model.intercept === model2.intercept)
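
The degenerate cases these tests pin down follow directly from the intercept initialization
described in the comment in LogisticRegression.scala: with zero coefficients, the intercept that
matches the label distribution is b = log(count_1 / count_0), which is finite only when both
classes are present. A small standalone sketch (plain Scala, not the Spark API; the counts are
illustrative only):

    object InterceptSketch {
      // b = log(P(1) / P(0)) = log(count1 / count0) when all coefficients are zero.
      def initialIntercept(count0: Double, count1: Double): Double =
        math.log(count1 / count0)

      def main(args: Array[String]): Unit = {
        println(initialIntercept(30.0, 70.0))   // mixed labels: log(70/30) ≈ 0.847
        println(initialIntercept(0.0, 100.0))   // all labels one  -> Infinity
        println(initialIntercept(100.0, 0.0))   // all labels zero -> -Infinity
      }
    }

Because no finite optimum exists in the last two cases when fitIntercept=true, the patch skips the
optimizer and returns zero coefficients with a +/-Infinity intercept, which is what the
totalIterations === 0 assertions above verify; without an intercept the optimizer still runs,
hence totalIterations > 0 in the fitIntercept=false cases.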