aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: MechCoder <manojkumarsivaraj334@gmail.com> 2015-03-20 17:14:09 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-03-20 17:14:09 -0700
commit: 25e271d9fbb3394931d23822a1b2020e9d9b46b3 (patch)
tree: 4fbb4b1b3b4406c2d1d54470aed8d3ee7968e0de /mllib/src/test
parent: a95043b1780bfde556db2dcc01511e40a12498dd (diff)
download: spark-25e271d9fbb3394931d23822a1b2020e9d9b46b3.tar.gz
spark-25e271d9fbb3394931d23822a1b2020e9d9b46b3.tar.bz2
spark-25e271d9fbb3394931d23822a1b2020e9d9b46b3.zip
[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve
Added evaluateEachIteration to allow the user to manually extract the error for each iteration of GradientBoosting. The internal optimisation can be dealt with later.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4906 from MechCoder/spark-6025 and squashes the following commits:

67146ab [MechCoder] Minor
352001f [MechCoder] Minor
6e8aa10 [MechCoder] Made the following changes Used mapPartition instead of map Refactored computeError and unpersisted broadcast variables
bc99ac6 [MechCoder] Refactor the method and stuff
dbda033 [MechCoder] [SPARK-6025] Add helper method evaluateEachIteration to extract learning curve
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala 16
1 file changed, 14 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index b437aeaaf0..55b0bac7d4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
- assert(gbtValidate.numTrees !== numIterations)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
}