aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorWeichenXu <WeichenXu123@outlook.com>2016-07-14 09:11:04 +0100
committerSean Owen <sowen@cloudera.com>2016-07-14 09:11:04 +0100
commit252d4f27f23b547777892bcea25a2cea62d8cbab (patch)
tree2d5d12dc618ba76ce627f3b0858d073323a433f6 /mllib
parentdb7317ac3c2fd2a11088d10060f168178dc99664 (diff)
downloadspark-252d4f27f23b547777892bcea25a2cea62d8cbab.tar.gz
spark-252d4f27f23b547777892bcea25a2cea62d8cbab.tar.bz2
spark-252d4f27f23b547777892bcea25a2cea62d8cbab.zip
[SPARK-16500][ML][MLLIB][OPTIMIZER] add LBFGS convergence warning for all used place in MLLib
## What changes were proposed in this pull request? Add warning_for the following case when LBFGS training not actually convergence: 1) LogisticRegression 2) AFTSurvivalRegression 3) LBFGS algorithm wrapper in mllib package ## How was this patch tested? N/A Author: WeichenXu <WeichenXu123@outlook.com> Closes #14157 from WeichenXu123/add_lbfgs_convergence_warning_for_all_used_place.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala6
3 files changed, 16 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index e157bdeb5b..4bab801bb3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -424,6 +424,11 @@ class LogisticRegression @Since("1.2.0") (
throw new SparkException(msg)
}
+ if (!state.actuallyConverged) {
+ logWarning("LogisticRegression training fininshed but the result " +
+ s"is not converged because: ${state.convergedReason.get.reason}")
+ }
+
/*
The coefficients are trained in the scaled space; we're converting them back to
the original space.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 7c51845a25..366448fc56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -245,6 +245,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
throw new SparkException(msg)
}
+ if (!state.actuallyConverged) {
+ logWarning("AFTSurvivalRegression training fininshed but the result " +
+ s"is not converged because: ${state.convergedReason.get.reason}")
+ }
+
state.x.toArray.clone()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index ec6ffe6e19..c61b2db6c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -212,6 +212,12 @@ object LBFGS extends Logging {
state = states.next()
}
lossHistory += state.value
+
+ if (!state.actuallyConverged) {
+ logWarning("LBFGS training fininshed but the result " +
+ s"is not converged because: ${state.convergedReason.get.reason}")
+ }
+
val weights = Vectors.fromBreeze(state.x)
val lossHistoryArray = lossHistory.result()