diff options
author | DB Tsai <dbt@netflix.com> | 2016-01-21 17:24:48 -0800 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2016-01-21 17:24:48 -0800 |
commit | b4574e387d0124667bdbb35f8c7c3e2065b14ba9 (patch) | |
tree | 78e45a4ee0cc07ae17a18b4ecbeece5c7e4588b3 /mllib | |
parent | 85200c09adc6eb98fadb8505f55cb44e3d8b3390 (diff) | |
download | spark-b4574e387d0124667bdbb35f8c7c3e2065b14ba9.tar.gz spark-b4574e387d0124667bdbb35f8c7c3e2065b14ba9.tar.bz2 spark-b4574e387d0124667bdbb35f8c7c3e2065b14ba9.zip |
[SPARK-12908][ML] Add warning message for LogisticRegression for potential converge issue
When all labels are the same, it's a dangerous ground for LogisticRegression without intercept to converge. GLMNET doesn't support this case, and will just exit. GLM can train, but will have a warning message saying the algorithm doesn't converge.
Author: DB Tsai <dbt@netflix.com>
Closes #10862 from dbtsai/add-tests.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index dad8dfc84e..c98a78a515 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -300,6 +300,14 @@ class LogisticRegression @Since("1.2.0") ( s"training is not needed.") (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double]) } else { + if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) { + logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } else if (!$(fitIntercept) && numClasses == 1) { + logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } + val featuresMean = summarizer.mean.toArray val featuresStd = summarizer.variance.toArray.map(math.sqrt) |