aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala45
1 files changed, 1 insertions, 44 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 747c267b4f..c61f89322d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,49 +158,6 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
- test("runWithValidation stops early and performs better on a validation dataset") {
- // Set numIterations large enough so that it stops early.
- val numIterations = 20
- val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
- val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
-
- val algos = Array(Regression, Regression, Classification)
- val losses = Array(SquaredError, AbsoluteError, LogLoss)
- algos.zip(losses).foreach { case (algo, loss) =>
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
- }
- }
- }
-
test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -220,7 +177,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
-private object GradientBoostedTreesSuite {
+private[spark] object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))