aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala15
2 files changed, 13 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 95ed48cea6..66a07e3136 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -262,7 +262,8 @@ object GradientBoostedTrees extends Logging {
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean()
- if (bestValidateError - currentValidateError < validationTol) {
+ if (bestValidateError - currentValidateError < validationTol * Math.max(
+ currentValidateError, 0.01)) {
doneLearning = true
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index b5c72fba3e..fc13bcfd8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -34,9 +34,16 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* weak hypotheses used in the final model.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
- * @param validationTol Useful when runWithValidation is used. If the error rate on the
- * validation input between two iterations is less than the validationTol
- * then stop. Ignored when
+ * @param validationTol validationTol is a condition which decides iteration termination when
+ * runWithValidation is used.
+ * The end of iteration is decided based on below logic:
+ * If the current loss on the validation set is > 0.01, the diff
+ * of validation error is compared to relative tolerance which is
+ * validationTol * (current loss on the validation set).
+ * If the current loss on the validation set is <= 0.01, the diff
+ * of validation error is compared to absolute tolerance which is
+ * validationTol * 0.01.
+ * Ignored when
* [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/
@Since("1.2.0")
@@ -48,7 +55,7 @@ case class BoostingStrategy @Since("1.4.0") (
// Optional boosting parameters
@Since("1.2.0") @BeanProperty var numIterations: Int = 100,
@Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
- @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
+ @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable {
/**
* Check validity of parameters.