author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-02-26 10:51:47 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-02-26 10:51:47 -0800 |
commit | cfff397f0adb27ca102cca43a7696e9fb1819ee0 (patch) | |
tree | 3388fa584b04c7642dbdfa365cc3ccb4b565a041 | |
parent | 2358657547016d647cdd2e2d363426fcd8d3e9ff (diff) | |
[SPARK-6004][MLlib] Pick the best model when training GradientBoostedTrees with validation
Because the validation error does not change monotonically, in practice it is better to pick the best model when training GradientBoostedTrees with validation than to stop training early.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #4763 from viirya/gbt_record_model and squashes the following commits:
452e049 [Liang-Chi Hsieh] Address comment.
ea2fae2 [Liang-Chi Hsieh] Pick the best model when training GradientBoostedTrees with validation.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index b4466ff409..a9c93e181e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -251,9 +251,15 @@ object GradientBoostedTrees extends Logging {
 
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
-
-    new GradientBoostedTreesModel(
-      boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    if (validate) {
+      new GradientBoostedTreesModel(
+        boostingStrategy.treeStrategy.algo,
+        baseLearners.slice(0, bestM),
+        baseLearnerWeights.slice(0, bestM))
+    } else {
+      new GradientBoostedTreesModel(
+        boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    }
   }
 
 }
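The following is a minimal, standalone sketch of the idea behind this change, not the Spark implementation. It assumes one validation error is recorded per boosting iteration and that the best prefix length plays the role of `bestM` in `slice(0, bestM)`. The names `BestModelSketch`, `bestPrefixLength`, and `firstIncreaseLength` are hypothetical, and `firstIncreaseLength` is a simplified stand-in for early stopping rather than Spark's actual stopping rule.

```scala
object BestModelSketch {
  // Hypothetical helper: number of leading learners whose ensemble
  // reached the lowest validation error (ties resolve to the earliest).
  def bestPrefixLength(validationErrors: Seq[Double]): Int = {
    require(validationErrors.nonEmpty, "need at least one validation error")
    validationErrors.indexOf(validationErrors.min) + 1
  }

  // Simplified stand-in for early stopping (an assumption, not Spark's rule):
  // keep learners up to the first iteration whose error rises.
  def firstIncreaseLength(validationErrors: Seq[Double]): Int = {
    val firstIncrease = validationErrors.sliding(2).indexWhere {
      case Seq(prev, next) => next > prev
      case _ => false
    }
    if (firstIncrease < 0) validationErrors.length else firstIncrease + 1
  }

  def main(args: Array[String]): Unit = {
    // Made-up, non-monotonic validation errors for five boosting iterations.
    val errors = Seq(0.90, 0.70, 0.75, 0.60, 0.65)
    val learners = Vector("tree0", "tree1", "tree2", "tree3", "tree4")
    val weights = Vector(1.0, 0.1, 0.1, 0.1, 0.1)

    val m = bestPrefixLength(errors)
    // Same slicing pattern as in the diff: keep only the first m learners.
    val keptLearners = learners.slice(0, m)
    val keptWeights = weights.slice(0, m)

    println(s"stop at first increase keeps ${firstIncreaseLength(errors)} learners")
    println(s"best model keeps $m learners: $keptLearners with weights $keptWeights")
  }
}
```

With these made-up errors, stopping at the first increase would keep two learners (error 0.70), while picking the best prefix keeps four (error 0.60), which is the situation the commit message describes.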