aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-10-08 11:27:46 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-10-08 11:27:46 -0700
commit22683560027806144c3a1141904b63eda86ae96c (patch)
treebb43a875ff9e629543349768f60826278ebd567e /mllib
parent0903c6489e9fa39db9575dace22a64015b9cd4c5 (diff)
downloadspark-22683560027806144c3a1141904b63eda86ae96c.tar.gz
spark-22683560027806144c3a1141904b63eda86ae96c.tar.bz2
spark-22683560027806144c3a1141904b63eda86ae96c.zip
[SPARK-7770] [ML] GBT validationTol change to compare with relative or absolute error
GBT compare ValidateError with tolerance switching between relative and absolute ones, where the former one is relative to the current loss on the training set. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8549 from yanboliang/spark-7770.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala15
2 files changed, 13 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 95ed48cea6..66a07e3136 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -262,7 +262,8 @@ object GradientBoostedTrees extends Logging {
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean()
- if (bestValidateError - currentValidateError < validationTol) {
+ if (bestValidateError - currentValidateError < validationTol * Math.max(
+ currentValidateError, 0.01)) {
doneLearning = true
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index b5c72fba3e..fc13bcfd8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -34,9 +34,16 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* weak hypotheses used in the final model.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
- * @param validationTol Useful when runWithValidation is used. If the error rate on the
- * validation input between two iterations is less than the validationTol
- * then stop. Ignored when
+ * @param validationTol validationTol is a condition which decides iteration termination when
+ * runWithValidation is used.
+ * The end of iteration is decided based on below logic:
+ * If the current loss on the validation set is > 0.01, the diff
+ * of validation error is compared to relative tolerance which is
+ * validationTol * (current loss on the validation set).
+ * If the current loss on the validation set is <= 0.01, the diff
+ * of validation error is compared to absolute tolerance which is
+ * validationTol * 0.01.
+ * Ignored when
* [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/
@Since("1.2.0")
@@ -48,7 +55,7 @@ case class BoostingStrategy @Since("1.4.0") (
// Optional boosting parameters
@Since("1.2.0") @BeanProperty var numIterations: Int = 100,
@Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
- @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
+ @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable {
/**
* Check validity of parameters.