aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-02-24 15:13:22 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-02-24 15:13:22 -0800
commit2a0fe34891882e0fde1b5722d8227aa99acc0f1f (patch)
tree238c58f540e0b8c727e131b6359041d137c4e780 /mllib/src/test
parentda505e59274d1c838653c1109db65ad374e65304 (diff)
downloadspark-2a0fe34891882e0fde1b5722d8227aa99acc0f1f.tar.gz
spark-2a0fe34891882e0fde1b5722d8227aa99acc0f1f.tar.bz2
spark-2a0fe34891882e0fde1b5722d8227aa99acc0f1f.zip
[SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation
Training can stop early if the decrease in error is less than a certain tolerance, or if the error increases because the training data is overfit. This introduces a new method, runWithValidation, which takes a pair of RDDs: one for the training data and the other for validation. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4677 from MechCoder/spark-5436 and squashes the following commits: 1bb21d4 [MechCoder] Combine regression and classification tests into a single one e4d799b [MechCoder] Addresses indentation and doc comments b48a70f [MechCoder] COSMIT b928a19 [MechCoder] Move validation while training section under usage tips fad9b6e [MechCoder] Made the following changes 1. Add section to documentation 2. Return corresponding to bestValidationError 3. Allow negative tolerance. 55e5c3b [MechCoder] One liner for prevValidateError 3e74372 [MechCoder] TST: Add test for classification 77549a9 [MechCoder] [SPARK-5436] Validate GradientBoostedTrees using runWithValidation
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala36
1 files changed, 36 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index bde47606eb..b437aeaaf0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,6 +158,40 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+   // Set numIterations large enough so that it stops early.
+   val numIterations = 20
+   val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
+   val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
+
+   val algos = Array(Regression, Regression, Classification)
+   val losses = Array(SquaredError, AbsoluteError, LogLoss)
+   // foreach, not map: each (algo, loss) pairing only runs assertions; nothing is collected.
+   algos.zip(losses).foreach { case (algo, loss) =>
+     val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+       categoricalFeaturesInfo = Map.empty)
+     val boostingStrategy =
+       new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+     val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+       .runWithValidation(trainRdd, validateRdd)
+     assert(gbtValidate.numTrees !== numIterations)
+
+     // Test that it performs better on the validation dataset.
+     val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+     val (errorWithoutValidation, errorWithValidation) = {
+       if (algo == Classification) {
+         // Rescale labels as 2 * label - 1 before computing the loss (0/1 becomes -1/+1).
+         val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+         (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+       } else {
+         (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
+       }
+     }
+     assert(errorWithValidation <= errorWithoutValidation)
+   }
+ }
+
}
private object GradientBoostedTreesSuite {
@@ -166,4 +200,6 @@ private object GradientBoostedTreesSuite {
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+ val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120)
+ val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80)
}