From 22683560027806144c3a1141904b63eda86ae96c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 8 Oct 2015 11:27:46 -0700 Subject: [SPARK-7770] [ML] GBT validationTol change to compare with relative or absolute error GBT compare ValidateError with tolerance switching between relative and absolute ones, where the former one is relative to the current loss on the training set. Author: Yanbo Liang Closes #8549 from yanboliang/spark-7770. --- .../apache/spark/mllib/tree/GradientBoostedTrees.scala | 3 ++- .../spark/mllib/tree/configuration/BoostingStrategy.scala | 15 +++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 95ed48cea6..66a07e3136 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -262,7 +262,8 @@ object GradientBoostedTrees extends Logging { validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() - if (bestValidateError - currentValidateError < validationTol) { + if (bestValidateError - currentValidateError < validationTol * Math.max( + currentValidateError, 0.01)) { doneLearning = true } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index b5c72fba3e..fc13bcfd8e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,9 +34,16 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param validationTol Useful when runWithValidation is used. If the error rate on the - * validation input between two iterations is less than the validationTol - * then stop. Ignored when + * @param validationTol validationTol is a condition which decides iteration termination when + * runWithValidation is used. + * The end of iteration is decided based on below logic: + * If the current loss on the validation set is > 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is <= 0.01, the diff + * of validation error is compared to absolute tolerance which is + * validationTol * 0.01. + * Ignored when * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Since("1.2.0") @@ -48,7 +55,7 @@ case class BoostingStrategy @Since("1.4.0") ( // Optional boosting parameters @Since("1.2.0") @BeanProperty var numIterations: Int = 100, @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, - @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable { + @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable { /** * Check validity of parameters. -- cgit v1.2.3