aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala12
1 files changed, 9 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index b4466ff409..a9c93e181e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -251,9 +251,15 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
-
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ if (validate) {
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo,
+ baseLearners.slice(0, bestM),
+ baseLearnerWeights.slice(0, bestM))
+ } else {
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ }
}
}