From 69bc2c17f1ca047d4915a4791b624d60c5943dc8 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 23 Mar 2016 15:08:47 -0700 Subject: [SPARK-13952][ML] Add random seed to GBT ## What changes were proposed in this pull request? `GBTClassifier` and `GBTRegressor` should use random seed for reproducible results. Because of the nature of current unit tests, which compare GBTs in ML and GBTs in MLlib for equality, I also added a random seed to MLlib GBT algorithm. I made alternate constructors in `mllib.tree.GradientBoostedTrees` to accept a random seed, but left them as private so as to not change the API unnecessarily. ## How was this patch tested? Existing unit tests verify that functionality did not change. Other ML algorithms do not seem to have unit tests that directly test the functionality of random seeding, but reproducibility with seeding for GBTs is effectively verified in existing tests. I can add more tests if needed. Author: sethah Closes #11903 from sethah/SPARK-13952. --- .../spark/ml/classification/GBTClassifier.scala | 8 +++--- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../apache/spark/ml/regression/GBTRegressor.scala | 8 +++--- .../spark/ml/tree/impl/GradientBoostedTrees.scala | 30 +++++++++++++--------- .../org/apache/spark/mllib/tree/DecisionTree.scala | 15 ++++++++--- .../spark/mllib/tree/GradientBoostedTrees.scala | 30 +++++++++++++++------- .../ml/classification/GBTClassifierSuite.scala | 4 ++- .../spark/ml/regression/GBTRegressorSuite.scala | 4 ++- .../mllib/tree/GradientBoostedTreesSuite.scala | 4 +-- 9 files changed, 66 insertions(+), 39 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 5a8845fdb6..c31df3aa18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") ( override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) @Since("1.4.0") - override def setSeed(value: Long): this.type = { - logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") - super.setSeed(value) - } + override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from GBTParams: @@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, + $(seed)) new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 428bc7a6d8..fa7cc436f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeRegressionModel = { val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 091e1d5fa8..da5b77e8fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) @Since("1.4.0") - override def setSeed(value: Long): this.type = { - logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") - super.setSeed(value) - } + override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from GBTParams: @Since("1.4.0") @@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, + $(seed)) new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index b9acc66472..1c8a9b4dfe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ - def run(input: RDD[LabeledPoint], - boostingStrategy: OldBoostingStrategy - ): (Array[DecisionTreeRegressionModel], Array[Double]) = { + def run( + input: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, + seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") } @@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging { * but it should follow the same distribution. * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ def runWithValidation( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], - boostingStrategy: OldBoostingStrategy - ): (Array[DecisionTreeRegressionModel], Array[Double]) = { + boostingStrategy: OldBoostingStrategy, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging { val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true) + validate = true, seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging { * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters * @param validate whether or not to use the validation dataset. + * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ @@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging { input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = { + validate: Boolean, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") - val firstTree = new DecisionTreeRegressor() + val firstTree = new DecisionTreeRegressor().setSeed(seed) val firstTreeModel = firstTree.train(input, treeStrategy) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel @@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") - val dt = new DecisionTreeRegressor() + val dt = new DecisionTreeRegressor().setSeed(seed + m) val model = dt.train(data, treeStrategy) timer.stop(s"building tree $m") // Update partial model diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8f02e098ac..c40d5e3fff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom * @param strategy The configuration parameters for the tree algorithm which specify the type * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. + * @param seed Random seed. */ @Since("1.0.0") -class DecisionTree @Since("1.0.0") (private val strategy: Strategy) +class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int) extends Serializable with Logging { + /** + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of decision tree (classification or regression), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ + @Since("1.0.0") + def this(strategy: Strategy) = this(strategy, seed = 0) + strategy.assertValid() /** @@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy) */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Note: random seed will not be used since numTrees = 1. - val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) + val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = seed) val rfModel = rf.run(input) rfModel.trees(0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index eb40fb0391..d166dc7905 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -47,11 +47,20 @@ import org.apache.spark.storage.StorageLevel * for other loss functions. * * @param boostingStrategy Parameters for the gradient boosting algorithm. + * @param seed Random seed. */ @Since("1.2.0") -class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) +class GradientBoostedTrees private[spark] ( + private val boostingStrategy: BoostingStrategy, + private val seed: Int) extends Serializable with Logging { + /** + * @param boostingStrategy Parameters for the gradient boosting algorithm. + */ + @Since("1.2.0") + def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0) + /** * Method to train a gradient boosting model * @@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti val algo = boostingStrategy.treeStrategy.algo algo match { case Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, + seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti val algo = boostingStrategy.treeStrategy.algo algo match { case Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true) + validate = true, seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging { def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { - new GradientBoostedTrees(boostingStrategy).run(input) + new GradientBoostedTrees(boostingStrategy, seed = 0).run(input) } /** @@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging { * @param validationInput Validation dataset, ignored if validate is set to false. * @param boostingStrategy Boosting parameters. * @param validate Whether or not to use the validation dataset. + * @param seed Random seed. * @return GradientBoostedTreesModel that can be used for prediction. */ private def boost( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: BoostingStrategy, - validate: Boolean): GradientBoostedTreesModel = { + validate: Boolean, + seed: Int): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(input) + val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight @@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") - val model = new DecisionTree(treeStrategy).run(data) + val model = new DecisionTree(treeStrategy, seed + m).run(data) timer.stop(s"building tree $m") // Update partial model baseLearners(m) = model diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 29efd675ab..f3680ed044 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setLossType("logistic") .setMaxIter(maxIter) .setStepSize(learningRate) + .setSeed(123) compareAPIs(data, None, gbt, categoricalFeatures) } } @@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxIter(5) .setStepSize(0.1) .setCheckpointInterval(2) + .setSeed(123) val model = gbt.fit(df) // copied model must have the same parent. @@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val oldGBT = new OldGBT(oldBoostingStrategy) + val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index db68606397..84148a8a4a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setLossType(loss) .setMaxIter(maxIter) .setStepSize(learningRate) + .setSeed(123) compareAPIs(data, None, gbt, categoricalFeatures) } } @@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxIter(5) .setStepSize(0.1) .setCheckpointInterval(2) + .setSeed(123) val model = gbt.fit(df) sc.checkpointDir = None @@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite { categoricalFeatures: Map[Int, Int]): Unit = { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val oldGBT = new OldGBT(oldBoostingStrategy) + val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 58828b3af9..747c267b4f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext categoricalFeaturesInfo = Map.empty) val boostingStrategy = new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) + val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0) .runWithValidation(trainRdd, validateRdd) val numTrees = gbtValidate.numTrees assert(numTrees !== numIterations) // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd) val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) { val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) -- cgit v1.2.3