diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-07-30 16:04:23 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-30 16:04:23 -0700 |
commit | be7be6d4c7d978c20e601d1f5f56ecb3479814cb (patch) | |
tree | 61bfc4231f86261a8ecb43b5f68fa4ad075b7037 /mllib/src/test/scala | |
parent | 7f7a319c4ce07f07a6bd68100cf0a4f1da66269e (diff) | |
download | spark-be7be6d4c7d978c20e601d1f5f56ecb3479814cb.tar.gz spark-be7be6d4c7d978c20e601d1f5f56ecb3479814cb.tar.bz2 spark-be7be6d4c7d978c20e601d1f5f56ecb3479814cb.zip |
[SPARK-6684] [MLLIB] [ML] Add checkpointing to GBTs
Add checkpointing to GradientBoostedTrees, GBTClassifier, GBTRegressor
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7804 from jkbradley/gbt-checkpoint3 and squashes the following commits:
3fbd7ba [Joseph K. Bradley] tiny fix
b3e160c [Joseph K. Bradley] unset checkpoint dir after test
9cc3a04 [Joseph K. Bradley] added checkpointing to GBTs
Diffstat (limited to 'mllib/src/test/scala')
3 files changed, 86 insertions, 33 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 82c345491b..a7bc77965f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 9682edcd9b..dbdce0c9de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions.min() < -1) } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 2521b33421..6fc9e8df62 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) - (algos zip losses) map { - case (algo, loss) => { - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) } } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } private object GradientBoostedTreesSuite { |