author: sethah <seth.hendrickson16@gmail.com> 2016-04-06 17:13:34 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-04-06 17:13:34 -0700
commit: bb873754b4700104755ab969694bf30945557dc3 (patch)
tree: 02b5b39b530827fea0871ade32e4b8927edb7e9a /mllib/src
parent: 864d1b4d665e2cc1d40b53502a4ddf26c1fbfc1d (diff)
[SPARK-12382][ML] Remove mllib GBT implementation and wrap ml
## What changes were proposed in this pull request?

This patch removes the gradient boosted trees implementation in mllib/tree/GradientBoostedTrees.scala and changes the mllib GBTs to call the implementation in spark.ML.

Primary changes:
* Removed the `boost` method in mllib GradientBoostedTrees.scala.
* Created a new test suite, GradientBoostedTreesSuite, in ML, containing the unit tests that were specific to GBT internals in mllib.

Other changes:
* Added an `updatePrediction` method to the `GradientBoostedTrees` object. This method provides consistency for code that builds predictions from boosted models: several methods hard-coded the prediction as sum_{i=1}^{numTrees} (treePrediction_i * treeWeight_i), and routing them through this one function ensures that the test methods checking accuracy use the same prediction rule the algorithm uses during training.
* Moved methods that were public but used only in testing into `GradientBoostedTrees`: `computeError` (previously part of the `Loss` trait) and `evaluateEachIteration`. These are used in the new spark.ML unit tests. They are left in mllib as well so as not to break the API.

## How was this patch tested?

Existing unit tests which compare ML and MLlib ensure that mllib GBTs have not changed. Only a single unit test was moved to ML; it verifies that `runWithValidation` performs as expected.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12050 from sethah/SPARK-12382.
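For illustration, the additive prediction that `updatePrediction` centralizes can be sketched in plain Scala, independent of Spark. This is a minimal sketch, not the patch's code; the per-tree predictor functions are hypothetical stand-ins for `tree.rootNode.predictImpl(features).prediction`:

```scala
// Sketch of the boosted-ensemble prediction sum_{i=1}^{numTrees} (treePrediction_i * treeWeight_i).
object EnsemblePredictSketch {
  // Hypothetical stand-in for a fitted regression tree.
  type TreePredictor = Array[Double] => Double

  def ensemblePredict(
      features: Array[Double],
      trees: Seq[TreePredictor],
      weights: Seq[Double]): Double = {
    // Each boosting iteration contributes one weighted tree prediction;
    // updatePrediction performs exactly one step of this fold.
    trees.zip(weights).foldLeft(0.0) { case (acc, (tree, w)) =>
      acc + tree(features) * w
    }
  }

  def main(args: Array[String]): Unit = {
    val stumps: Seq[TreePredictor] = Seq(f => f(0), f => -0.5 * f(0))
    // 2.0 * 1.0 + (-1.0) * 0.1 = 1.9
    println(ensemblePredict(Array(2.0), stumps, Seq(1.0, 0.1)))
  }
}
```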
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala | 115
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala | 182
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala | 85
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala | 45
6 files changed, 207 insertions, 226 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 0749d93b7d..d365655674 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.internal.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
-import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -30,7 +29,24 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-private[ml] object GradientBoostedTrees extends Logging {
+
+/**
+ * A package that implements
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]]
+ * for regression and binary classification.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - When the loss is SquaredError, these methods give the same result, but they could differ
+ * for other loss functions.
+ */
+private[spark] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
@@ -107,7 +123,7 @@ private[ml] object GradientBoostedTrees extends Logging {
initTree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
data.map { lp =>
- val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
+ val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight)
val error = loss.computeError(pred, lp.label)
(pred, error)
}
@@ -133,7 +149,7 @@ private[ml] object GradientBoostedTrees extends Logging {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
iter.map { case (lp, (pred, error)) =>
- val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
+ val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
val newError = loss.computeError(newPred, lp.label)
(newPred, newError)
}
@@ -142,6 +158,97 @@ private[ml] object GradientBoostedTrees extends Logging {
}
/**
+ * Add prediction from a new boosting iteration to an existing prediction.
+ *
+ * @param features Vector of features representing a single data point.
+ * @param prediction The existing prediction.
+ * @param tree New Decision Tree model.
+ * @param weight Tree weight.
+ * @return Updated prediction.
+ */
+ def updatePrediction(
+ features: Vector,
+ prediction: Double,
+ tree: DecisionTreeRegressionModel,
+ weight: Double): Double = {
+ prediction + tree.rootNode.predictImpl(features).prediction * weight
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @return Measure of model error on data
+ */
+ def computeError(
+ data: RDD[LabeledPoint],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss): Double = {
+ data.map { lp =>
+ val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
+ updatePrediction(lp.features, acc, model, weight)
+ }
+ loss.computeError(predicted, lp.label)
+ }.mean()
+ }
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @param algo algorithm for the ensemble, either Classification or Regression
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[LabeledPoint],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss,
+ algo: OldAlgo.Value): Array[Double] = {
+
+ val sc = data.sparkContext
+ val remappedData = algo match {
+ case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case _ => data
+ }
+
+ val numIterations = trees.length
+ val evaluationArray = Array.fill(numIterations)(0.0)
+ val localTreeWeights = treeWeights
+
+ var predictionAndError = computeInitialPredictionAndError(
+ remappedData, localTreeWeights(0), trees(0), loss)
+
+ evaluationArray(0) = predictionAndError.values.mean()
+
+ val broadcastTrees = sc.broadcast(trees)
+ (1 until numIterations).foreach { nTree =>
+ predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+ val currentTree = broadcastTrees.value(nTree)
+ val currentTreeWeight = localTreeWeights(nTree)
+ iter.map { case (point, (pred, error)) =>
+ val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight)
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
+ }
+ }
+ evaluationArray(nTree) = predictionAndError.values.mean()
+ }
+
+ broadcastTrees.unpersist()
+ evaluationArray
+ }
+
+ /**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param validationInput validation dataset, ignored if validate is set to false.
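Taken together, the new helpers give spark.ml a self-contained evaluation path. A hedged usage sketch follows; it is not part of the patch, and since the object is `private[spark]` it would have to live under the `org.apache.spark` namespace (e.g. in a test). Here `data`, `trees`, and `treeWeights` are assumed to come from a prior `GradientBoostedTrees.run` call:

```scala
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.rdd.RDD

def evaluationSketch(
    data: RDD[LabeledPoint],
    trees: Array[DecisionTreeRegressionModel],
    treeWeights: Array[Double]): Unit = {
  // Mean loss of the full ensemble over `data`.
  val overallError = GradientBoostedTrees.computeError(data, trees, treeWeights, SquaredError)
  // Element i holds the mean loss of the ensemble truncated to the first i + 1 trees.
  val perIteration = GradientBoostedTrees.evaluateEachIteration(
    data, trees, treeWeights, SquaredError, OldAlgo.Regression)
  // For regression (no label remap) the last entry matches the full-ensemble error.
  assert(math.abs(perIteration.last - overallError) < 1e-12)
}
```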
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index db0ff28d82..c4ab673d9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -133,8 +133,8 @@ private[ml] object TreeEnsembleModel {
* following the explanation of Gini importance from "Random Forests" documentation
* by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
*
- * For collections of trees, including boosting and bagging, Hastie et al.
- * propose to use the average of single tree importances across all trees in the ensemble.
+ * For collections of trees, including boosting and bagging, Hastie et al.
+ * propose to use the average of single tree importances across all trees in the ensemble.
*
* This feature importance is calculated as follows:
* - Average over trees:
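The whitespace-only doc fix above describes ensemble feature importance as the average of single-tree importances (Hastie et al.). As a hedged sketch of that averaging, over plain arrays rather than the patch's tree types:

```scala
// Ensemble importance as the mean of per-tree importance vectors.
def averageImportances(perTree: Seq[Array[Double]]): Array[Double] = {
  require(perTree.nonEmpty && perTree.forall(_.length == perTree.head.length))
  val numFeatures = perTree.head.length
  val summed = perTree.foldLeft(new Array[Double](numFeatures)) { (acc, imp) =>
    Array.tabulate(numFeatures)(i => acc(i) + imp(i))
  }
  summed.map(_ / perTree.size) // average across all trees in the ensemble
}
```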
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 0f0c6b466d..7fe60e2d99 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,15 +20,11 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.tree.impl.TimeTracker
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
/**
* A class that implements
@@ -70,17 +66,8 @@ class GradientBoostedTrees private[spark] (
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- algo match {
- case Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
- seed)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
+ val (trees, treeWeights) = NewGBT.run(input, boostingStrategy, seed.toLong)
+ new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
/**
@@ -107,20 +94,9 @@ class GradientBoostedTrees private[spark] (
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- algo match {
- case Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- val remappedValidationInput = validationInput.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true, seed)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
+ val (trees, treeWeights) = NewGBT.runWithValidation(input, validationInput, boostingStrategy,
+ seed.toLong)
+ new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
/**
@@ -162,148 +138,4 @@ object GradientBoostedTrees extends Logging {
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
train(input.rdd, boostingStrategy)
}
-
- /**
- * Internal method for performing regression using trees as base learners.
- *
- * @param input Training dataset.
- * @param validationInput Validation dataset, ignored if validate is set to false.
- * @param boostingStrategy Boosting parameters.
- * @param validate Whether or not to use the validation dataset.
- * @param seed Random seed.
- * @return GradientBoostedTreesModel that can be used for prediction.
- */
- private def boost(
- input: RDD[LabeledPoint],
- validationInput: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy,
- validate: Boolean,
- seed: Int): GradientBoostedTreesModel = {
- val timer = new TimeTracker()
- timer.start("total")
- timer.start("init")
-
- boostingStrategy.assertValid()
-
- // Initialize gradient boosting parameters
- val numIterations = boostingStrategy.numIterations
- val baseLearners = new Array[DecisionTreeModel](numIterations)
- val baseLearnerWeights = new Array[Double](numIterations)
- val loss = boostingStrategy.loss
- val learningRate = boostingStrategy.learningRate
- // Prepare strategy for individual trees, which use regression with variance impurity.
- val treeStrategy = boostingStrategy.treeStrategy.copy
- val validationTol = boostingStrategy.validationTol
- treeStrategy.algo = Regression
- treeStrategy.impurity = Variance
- treeStrategy.assertValid()
-
- // Cache input
- val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
- input.persist(StorageLevel.MEMORY_AND_DISK)
- true
- } else {
- false
- }
-
- // Prepare periodic checkpointers
- val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
- treeStrategy.getCheckpointInterval, input.sparkContext)
- val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
- treeStrategy.getCheckpointInterval, input.sparkContext)
-
- timer.stop("init")
-
- logDebug("##########")
- logDebug("Building tree 0")
- logDebug("##########")
-
- // Initialize tree
- timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
- val firstTreeWeight = 1.0
- baseLearners(0) = firstTreeModel
- baseLearnerWeights(0) = firstTreeWeight
-
- var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
- computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
- predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
-
- // Note: A model of type regression is used since we require raw prediction
- timer.stop("building tree 0")
-
- var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
- computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
- if (validate) validatePredErrorCheckpointer.update(validatePredError)
- var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
- var bestM = 1
-
- var m = 1
- var doneLearning = false
- while (m < numIterations && !doneLearning) {
- // Update data with pseudo-residuals
- val data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
-
- timer.start(s"building tree $m")
- logDebug("###################################################")
- logDebug("Gradient boosting tree iteration " + m)
- logDebug("###################################################")
- val model = new DecisionTree(treeStrategy, seed + m).run(data)
- timer.stop(s"building tree $m")
- // Update partial model
- baseLearners(m) = model
- // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
- // Technically, the weight should be optimized for the particular loss.
- // However, the behavior should be reasonable, though not optimal.
- baseLearnerWeights(m) = learningRate
-
- predError = GradientBoostedTreesModel.updatePredictionError(
- input, predError, baseLearnerWeights(m), baseLearners(m), loss)
- predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
-
- if (validate) {
- // Stop training early if
- // 1. Reduction in error is less than the validationTol or
- // 2. If the error increases, that is if the model is overfit.
- // We want the model returned corresponding to the best validation error.
-
- validatePredError = GradientBoostedTreesModel.updatePredictionError(
- validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
- validatePredErrorCheckpointer.update(validatePredError)
- val currentValidateError = validatePredError.values.mean()
- if (bestValidateError - currentValidateError < validationTol * Math.max(
- currentValidateError, 0.01)) {
- doneLearning = true
- } else if (currentValidateError < bestValidateError) {
- bestValidateError = currentValidateError
- bestM = m + 1
- }
- }
- m += 1
- }
-
- timer.stop("total")
-
- logInfo("Internal timing for DecisionTree:")
- logInfo(s"$timer")
-
- predErrorCheckpointer.deleteAllCheckpoints()
- validatePredErrorCheckpointer.deleteAllCheckpoints()
- if (persistedInput) input.unpersist()
-
- if (validate) {
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo,
- baseLearners.slice(0, bestM),
- baseLearnerWeights.slice(0, bestM))
- } else {
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
- }
- }
-
}
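The early-stopping rule deleted here survives unchanged in the spark.ml implementation: training halts when the validation-error improvement falls below `validationTol` relative to the current error (floored at 0.01). A worked instance of the predicate, with made-up numbers:

```scala
// Early-stopping predicate from the boost loop, evaluated on sample values.
val validationTol = 0.001
val bestValidateError = 0.2500    // best validation error seen so far
val currentValidateError = 0.2499 // error after the current iteration
val doneLearning = bestValidateError - currentValidateError <
  validationTol * math.max(currentValidateError, 0.01)
// Improvement ~0.0001 < threshold ~0.00025, so doneLearning == true:
// training stops and the model keeps only the best iterations seen.
println(doneLearning)
```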
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 914818f41f..3c11631f98 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -53,7 +53,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
- test("Regression with continuous features: SquaredError") {
+ test("Regression with continuous features") {
val categoricalFeatures = Map.empty[Int, Int]
GBTRegressor.supportedLossTypes.foreach { loss =>
testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
new file mode 100644
index 0000000000..fecf372c3d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suite for [[GradientBoostedTrees]].
+ */
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ // Set numIterations large enough so that it stops early.
+ val numIterations = 20
+ val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
+ val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)
+ val trainDF = sqlContext.createDataFrame(trainRdd)
+ val validateDF = sqlContext.createDataFrame(validateRdd)
+
+ val algos = Array(Regression, Regression, Classification)
+ val losses = Array(SquaredError, AbsoluteError, LogLoss)
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val (validateTrees, validateTreeWeights) = GradientBoostedTrees
+ .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+ val numTrees = validateTrees.length
+ assert(numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeError(remappedRdd, validateTrees,
+ validateTreeWeights, loss))
+ } else {
+ (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeError(validateRdd, validateTrees,
+ validateTreeWeights, loss))
+ }
+ }
+ assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validateRdd, trees, treeWeights, loss, algo)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
+}
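The Classification branch above remaps 0/1 labels to -1/+1 before scoring, mirroring the remap inside `evaluateEachIteration`. A quick standalone check of the transformation:

```scala
// Binary classification is reduced to regression on {-1, +1} labels.
val remapped = Seq(0.0, 1.0).map(label => (label * 2) - 1)
assert(remapped == Seq(-1.0, 1.0)) // LogLoss expects labels in {-1, +1}
```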
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 747c267b4f..c61f89322d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,49 +158,6 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
- test("runWithValidation stops early and performs better on a validation dataset") {
- // Set numIterations large enough so that it stops early.
- val numIterations = 20
- val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
- val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
-
- val algos = Array(Regression, Regression, Classification)
- val losses = Array(SquaredError, AbsoluteError, LogLoss)
- algos.zip(losses).foreach { case (algo, loss) =>
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
- }
- }
- }
-
test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -220,7 +177,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
-private object GradientBoostedTreesSuite {
+private[spark] object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))