aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-02-26 10:51:47 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-02-26 10:51:47 -0800
commitcfff397f0adb27ca102cca43a7696e9fb1819ee0 (patch)
tree3388fa584b04c7642dbdfa365cc3ccb4b565a041 /mllib
parent2358657547016d647cdd2e2d363426fcd8d3e9ff (diff)
downloadspark-cfff397f0adb27ca102cca43a7696e9fb1819ee0.tar.gz
spark-cfff397f0adb27ca102cca43a7696e9fb1819ee0.tar.bz2
spark-cfff397f0adb27ca102cca43a7696e9fb1819ee0.zip
[SPARK-6004][MLlib] Pick the best model when training GradientBoostedTrees with validation
Since the validation error does not change monotonically, in practice, it should be proper to pick the best model when training GradientBoostedTrees with validation instead of stopping it early. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #4763 from viirya/gbt_record_model and squashes the following commits: 452e049 [Liang-Chi Hsieh] Address comment. ea2fae2 [Liang-Chi Hsieh] Pick the best model when training GradientBoostedTrees with validation.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala12
1 files changed, 9 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index b4466ff409..a9c93e181e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -251,9 +251,15 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
-
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ if (validate) {
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo,
+ baseLearners.slice(0, bestM),
+ baseLearnerWeights.slice(0, bestM))
+ } else {
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ }
}
}