author: DB Tsai <dbt@netflix.com> 2016-01-21 17:24:48 -0800
committer: DB Tsai <dbt@netflix.com> 2016-01-21 17:24:48 -0800
commit: b4574e387d0124667bdbb35f8c7c3e2065b14ba9
tree: 78e45a4ee0cc07ae17a18b4ecbeece5c7e4588b3 /mllib
parent: 85200c09adc6eb98fadb8505f55cb44e3d8b3390
[SPARK-12908][ML] Add warning message for LogisticRegression for potential converge issue
When all labels are the same, LogisticRegression without an intercept is on dangerous ground and may fail to converge. GLMNET doesn't support this case and will just exit. GLM can train, but emits a warning message saying the algorithm doesn't converge.

Author: DB Tsai <dbt@netflix.com>

Closes #10862 from dbtsai/add-tests.
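To illustrate why this case is risky, here is a minimal sketch (not part of the commit) of the mean logistic loss for a single weight with no intercept, when every label is one. The object name `AllOnesNoIntercept`, the helper `loss`, and the feature values are hypothetical, and the sketch assumes all feature values are positive. Under that assumption the loss keeps shrinking as the weight grows, so there is no finite minimizer for an optimizer to settle on.

```scala
object AllOnesNoIntercept {
  // Mean logistic loss when every label is 1, with one weight w and no intercept:
  //   loss(w) = mean_i log(1 + exp(-w * x_i))
  // Hypothetical helper for illustration only; not Spark's implementation.
  def loss(w: Double, xs: Seq[Double]): Double =
    xs.map(x => math.log1p(math.exp(-w * x))).sum / xs.size

  def main(args: Array[String]): Unit = {
    // Assumed feature values, all positive.
    val xs = Seq(0.5, 1.0, 2.0)
    // The loss decreases for every larger w we try.
    Seq(0.0, 1.0, 10.0, 100.0).foreach { w =>
      println(f"w = $w%6.1f  loss = ${loss(w, xs)}%.6e")
    }
  }
}
```

If the feature values have mixed signs, a finite minimizer can exist, which is consistent with the warning saying the algorithm "may" not converge rather than that it will not.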
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 8
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index dad8dfc84e..c98a78a515 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -300,6 +300,14 @@ class LogisticRegression @Since("1.2.0") (
           s"training is not needed.")
         (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
       } else {
+        if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
+          logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " +
+            s"so the algorithm may not converge.")
+        } else if (!$(fitIntercept) && numClasses == 1) {
+          logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " +
+            s"so the algorithm may not converge.")
+        }
+
         val featuresMean = summarizer.mean.toArray
         val featuresStd = summarizer.variance.toArray.map(math.sqrt)
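
The condition added in this hunk can be exercised outside Spark with a small standalone sketch. The object `ConvergenceWarning` and the function `warningFor` are hypothetical names, not part of the patch. The sketch assumes `histogram` holds per-label counts indexed by label value and that `numClasses` equals its length, which is how the names read in the hunk but is not shown in it.

```scala
object ConvergenceWarning {
  // Hypothetical, Spark-free restatement of the check added by the patch.
  // Assumption: histogram(i) is the count of examples with label i,
  // and numClasses is the histogram length.
  def warningFor(fitIntercept: Boolean, histogram: Array[Double]): Option[String] = {
    val numClasses = histogram.length
    if (!fitIntercept && numClasses == 2 && histogram(0) == 0.0) {
      // Only label 1 was observed.
      Some("All labels are one and fitIntercept=false. It's a dangerous ground, " +
        "so the algorithm may not converge.")
    } else if (!fitIntercept && numClasses == 1) {
      // Only label 0 was observed.
      Some("All labels are zero and fitIntercept=false. It's a dangerous ground, " +
        "so the algorithm may not converge.")
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    // Five examples, all labeled 1, no intercept: a warning is returned.
    warningFor(fitIntercept = false, histogram = Array(0.0, 5.0)).foreach(println)
  }
}
```

Returning an `Option[String]` instead of calling `logWarning` directly is a choice made here only so the condition can be checked without a logger.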