aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-02-24 15:13:22 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-02-24 15:13:22 -0800
commit2a0fe34891882e0fde1b5722d8227aa99acc0f1f (patch)
tree238c58f540e0b8c727e131b6359041d137c4e780 /mllib
parentda505e59274d1c838653c1109db65ad374e65304 (diff)
downloadspark-2a0fe34891882e0fde1b5722d8227aa99acc0f1f.tar.gz
spark-2a0fe34891882e0fde1b5722d8227aa99acc0f1f.tar.bz2
spark-2a0fe34891882e0fde1b5722d8227aa99acc0f1f.zip
[SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation
One can stop early if the decrease in error rate is less than a certain tolerance, or if the error increases — that is, if the training data is overfit. This introduces a new method runWithValidation which takes in a pair of RDDs, one for the training data and the other for the validation. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4677 from MechCoder/spark-5436 and squashes the following commits: 1bb21d4 [MechCoder] Combine regression and classification tests into a single one e4d799b [MechCoder] Addresses indentation and doc comments b48a70f [MechCoder] COSMIT b928a19 [MechCoder] Move validation while training section under usage tips fad9b6e [MechCoder] Made the following changes 1. Add section to documentation 2. Return corresponding to bestValidationError 3. Allow negative tolerance. 55e5c3b [MechCoder] One liner for prevValidateError 3e74372 [MechCoder] TST: Add test for classification 77549a9 [MechCoder] [SPARK-5436] Validate GradientBoostedTrees using runWithValidation
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala75
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala36
3 files changed, 111 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 61f6b1313f..b4466ff409 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
- case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+ case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+ GradientBoostedTrees.boost(remappedInput,
+ remappedInput, boostingStrategy, validate=false)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
run(input.rdd)
}
-}
+ /**
+ * Method to validate a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param validationInput Validation dataset:
+ RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ Should be different from and follow the same distribution as input.
+ e.g., these two datasets could be created from an original dataset
+ by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ def runWithValidation(
+ input: RDD[LabeledPoint],
+ validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case Regression => GradientBoostedTrees.boost(
+ input, validationInput, boostingStrategy, validate=true)
+ case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val remappedValidationInput = validationInput.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
+ validate=true)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+ }
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]].
+ */
+ def runWithValidation(
+ input: JavaRDD[LabeledPoint],
+ validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+ runWithValidation(input.rdd, validationInput.rdd)
+ }
+}
object GradientBoostedTrees extends Logging {
@@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging {
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
+ * @param validationInput validation dataset, ignored if validate is set to false.
* @param boostingStrategy boosting parameters
+ * @param validate whether or not to use the validation dataset.
* @return a gradient boosted trees model that can be used for prediction
*/
private def boost(
input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ validationInput: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy,
+ validate: Boolean): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
@@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging {
val learningRate = boostingStrategy.learningRate
// Prepare strategy for individual trees, which use regression with variance impurity.
val treeStrategy = boostingStrategy.treeStrategy.copy
+ val validationTol = boostingStrategy.validationTol
treeStrategy.algo = Regression
treeStrategy.impurity = Variance
treeStrategy.assertValid()
@@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging {
baseLearnerWeights(0) = 1.0
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
logDebug("error of gbt = " + loss.computeError(startingModel, input))
+
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
+ var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+ var bestM = 1
+
// pseudo-residual for second iteration
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
point.features))
-
var m = 1
while (m < numIterations) {
timer.start(s"building tree $m")
@@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging {
val partialModel = new GradientBoostedTreesModel(
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
logDebug("error of gbt = " + loss.computeError(partialModel, input))
+
+ if (validate) {
+ // Stop training early if
+ // 1. Reduction in error is less than the validationTol or
+ // 2. If the error increases, that is if the model is overfit.
+ // We want the model returned corresponding to the best validation error.
+ val currentValidateError = loss.computeError(partialModel, validationInput)
+ if (bestValidateError - currentValidateError < validationTol) {
+ return new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo,
+ baseLearners.slice(0, bestM),
+ baseLearnerWeights.slice(0, bestM))
+ } else if (currentValidateError < bestValidateError) {
+ bestValidateError = currentValidateError
+ bestM = m + 1
+ }
+ }
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
point.features))
@@ -191,4 +255,5 @@ object GradientBoostedTrees extends Logging {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index ed8e6a796f..664c8df019 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* weak hypotheses used in the final model.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
+ * @param validationTol Useful when runWithValidation is used. If the error rate on the
+ * validation input between two iterations is less than the validationTol
+ * then stop. Ignored when [[run]] is used.
*/
@Experimental
case class BoostingStrategy(
@@ -42,7 +45,8 @@ case class BoostingStrategy(
@BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var numIterations: Int = 100,
- @BeanProperty var learningRate: Double = 0.1) extends Serializable {
+ @BeanProperty var learningRate: Double = 0.1,
+ @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
/**
* Check validity of parameters.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index bde47606eb..b437aeaaf0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,6 +158,40 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ // Set numIterations large enough so that it stops early.
+ val numIterations = 20
+ val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
+ val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
+
+ val algos = Array(Regression, Regression, Classification)
+ val losses = Array(SquaredError, AbsoluteError, LogLoss)
+ (algos zip losses) map {
+ case (algo, loss) => {
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ .runWithValidation(trainRdd, validateRdd)
+ assert(gbtValidate.numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+ } else {
+ (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
+ }
+ }
+ assert(errorWithValidation <= errorWithoutValidation)
+ }
+ }
+ }
+
}
private object GradientBoostedTreesSuite {
@@ -166,4 +200,6 @@ private object GradientBoostedTreesSuite {
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+ val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120)
+ val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80)
}