aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-07-30 16:04:23 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-30 16:04:23 -0700
commitbe7be6d4c7d978c20e601d1f5f56ecb3479814cb (patch)
tree61bfc4231f86261a8ecb43b5f68fa4ad075b7037 /mllib
parent7f7a319c4ce07f07a6bd68100cf0a4f1da66269e (diff)
downloadspark-be7be6d4c7d978c20e601d1f5f56ecb3479814cb.tar.gz
spark-be7be6d4c7d978c20e601d1f5f56ecb3479814cb.tar.bz2
spark-be7be6d4c7d978c20e601d1f5f56ecb3479814cb.zip
[SPARK-6684] [MLLIB] [ML] Add checkpointing to GBTs
Add checkpointing to GradientBoostedTrees, GBTClassifier, GBTRegressor CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #7804 from jkbradley/gbt-checkpoint3 and squashes the following commits: 3fbd7ba [Joseph K. Bradley] tiny fix b3e160c [Joseph K. Bradley] unset checkpoint dir after test 9cc3a04 [Joseph K. Bradley] added checkpointing to GBTs
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala48
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala79
6 files changed, 114 insertions, 57 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 9dbec41efe..d6f8b29a43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
this.checkpointInterval = lda.getCheckpointInterval
this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
checkpointInterval, graph.vertices.sparkContext)
+ this.graphCheckpointer.update(this.graph)
this.globalTopicTotals = computeGlobalTopicTotals()
this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index a835f96d5d..9ce6faa137 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging {
false
}
+ // Prepare periodic checkpointers
+ val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+ val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+
timer.stop("init")
logDebug("##########")
logDebug("Building tree 0")
logDebug("##########")
- var data = input
// Initialize tree
timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+ val firstTreeModel = new DecisionTree(treeStrategy).run(input)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+ predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
// Note: A model of type regression is used since we require raw prediction
@@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging {
var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+ if (validate) validatePredErrorCheckpointer.update(validatePredError)
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1
- // pseudo-residual for second iteration
- data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
-
var m = 1
- while (m < numIterations) {
+ var doneLearning = false
+ while (m < numIterations && !doneLearning) {
+ // Update data with pseudo-residuals
+ val data = predError.zip(input).map { case ((pred, _), point) =>
+ LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ }
+
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
- // Create partial model
+ // Update partial model
baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
- // Note: A model of type regression is used since we require raw prediction
- val partialModel = new GradientBoostedTreesModel(
- Regression, baseLearners.slice(0, m + 1),
- baseLearnerWeights.slice(0, m + 1))
predError = GradientBoostedTreesModel.updatePredictionError(
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+ predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
if (validate) {
@@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging {
validatePredError = GradientBoostedTreesModel.updatePredictionError(
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+ validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean()
if (bestValidateError - currentValidateError < validationTol) {
- return new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo,
- baseLearners.slice(0, bestM),
- baseLearnerWeights.slice(0, bestM))
+ doneLearning = true
} else if (currentValidateError < bestValidateError) {
- bestValidateError = currentValidateError
- bestM = m + 1
+ bestValidateError = currentValidateError
+ bestM = m + 1
}
}
- // Update data with pseudo-residuals
- data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
m += 1
}
@@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ predErrorCheckpointer.deleteAllCheckpoints()
+ validatePredErrorCheckpointer.deleteAllCheckpoints()
if (persistedInput) input.unpersist()
if (validate) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 2d6b01524f..9fd30c9b56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* learning rate should be between in the interval (0, 1]
* @param validationTol Useful when runWithValidation is used. If the error rate on the
* validation input between two iterations is less than the validationTol
- * then stop. Ignored when [[run]] is used.
+ * then stop. Ignored when
+ * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/
@Experimental
case class BoostingStrategy(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 82c345491b..a7bc77965f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
/**
@@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setLossType("logistic")
+ .setMaxIter(5)
+ .setStepSize(0.1)
+ .setCheckpointInterval(2)
+ val model = gbt.fit(df)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 9682edcd9b..dbdce0c9de 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees =>
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
/**
@@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predictions.min() < -1)
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val df = sqlContext.createDataFrame(data)
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxIter(5)
+ .setStepSize(0.1)
+ .setCheckpointInterval(2)
+ val model = gbt.fit(df)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 2521b33421..6fc9e8df62 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
- (algos zip losses) map {
- case (algo, loss) => {
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ .runWithValidation(trainRdd, validateRdd)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+ } else {
+ (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
}
}
+ assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
+ val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
+
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
}
private object GradientBoostedTreesSuite {