diff options
9 files changed, 66 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 5a8845fdb6..c31df3aa18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") ( override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) @Since("1.4.0") - override def setSeed(value: Long): this.type = { - logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") - super.setSeed(value) - } + override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from GBTParams: @@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, + $(seed)) new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 428bc7a6d8..fa7cc436f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeRegressionModel = { val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 091e1d5fa8..da5b77e8fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) @Since("1.4.0") - override def setSeed(value: Long): this.type = { - logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") - super.setSeed(value) - } + override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from GBTParams: @Since("1.4.0") @@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, + $(seed)) new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index b9acc66472..1c8a9b4dfe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ - def run(input: RDD[LabeledPoint], - boostingStrategy: OldBoostingStrategy - ): (Array[DecisionTreeRegressionModel], Array[Double]) = { + def run( + input: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, + seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") } @@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging { * but it should follow the same distribution. * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ def runWithValidation( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], - boostingStrategy: OldBoostingStrategy - ): (Array[DecisionTreeRegressionModel], Array[Double]) = { + boostingStrategy: OldBoostingStrategy, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { val algo = boostingStrategy.treeStrategy.algo algo match { case OldAlgo.Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) case OldAlgo.Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging { val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true) + validate = true, seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging { * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters * @param validate whether or not to use the validation dataset. + * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) */ @@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging { input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: OldBoostingStrategy, - validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = { + validate: Boolean, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") - val firstTree = new DecisionTreeRegressor() + val firstTree = new DecisionTreeRegressor().setSeed(seed) val firstTreeModel = firstTree.train(input, treeStrategy) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel @@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") - val dt = new DecisionTreeRegressor() + val dt = new DecisionTreeRegressor().setSeed(seed + m) val model = dt.train(data, treeStrategy) timer.stop(s"building tree $m") // Update partial model diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8f02e098ac..c40d5e3fff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom * @param strategy The configuration parameters for the tree algorithm which specify the type * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. + * @param seed Random seed. */ @Since("1.0.0") -class DecisionTree @Since("1.0.0") (private val strategy: Strategy) +class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int) extends Serializable with Logging { + /** + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of decision tree (classification or regression), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ + @Since("1.0.0") + def this(strategy: Strategy) = this(strategy, seed = 0) + strategy.assertValid() /** @@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy) */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Note: random seed will not be used since numTrees = 1. - val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) + val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = seed) val rfModel = rf.run(input) rfModel.trees(0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index eb40fb0391..d166dc7905 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -47,12 +47,21 @@ import org.apache.spark.storage.StorageLevel * for other loss functions. * * @param boostingStrategy Parameters for the gradient boosting algorithm. + * @param seed Random seed. */ @Since("1.2.0") -class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) +class GradientBoostedTrees private[spark] ( + private val boostingStrategy: BoostingStrategy, + private val seed: Int) extends Serializable with Logging { /** + * @param boostingStrategy Parameters for the gradient boosting algorithm. + */ + @Since("1.2.0") + def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0) + + /** * Method to train a gradient boosting model * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. @@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti val algo = boostingStrategy.treeStrategy.algo algo match { case Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, + seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti val algo = boostingStrategy.treeStrategy.algo algo match { case Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true) + validate = true, seed) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging { def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { - new GradientBoostedTrees(boostingStrategy).run(input) + new GradientBoostedTrees(boostingStrategy, seed = 0).run(input) } /** @@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging { * @param validationInput Validation dataset, ignored if validate is set to false. * @param boostingStrategy Boosting parameters. * @param validate Whether or not to use the validation dataset. + * @param seed Random seed. * @return GradientBoostedTreesModel that can be used for prediction. */ private def boost( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint], boostingStrategy: BoostingStrategy, - validate: Boolean): GradientBoostedTreesModel = { + validate: Boolean, + seed: Int): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(input) + val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight @@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging { logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") - val model = new DecisionTree(treeStrategy).run(data) + val model = new DecisionTree(treeStrategy, seed + m).run(data) timer.stop(s"building tree $m") // Update partial model baseLearners(m) = model diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 29efd675ab..f3680ed044 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setLossType("logistic") .setMaxIter(maxIter) .setStepSize(learningRate) + .setSeed(123) compareAPIs(data, None, gbt, categoricalFeatures) } } @@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxIter(5) .setStepSize(0.1) .setCheckpointInterval(2) + .setSeed(123) val model = gbt.fit(df) // copied model must have the same parent. @@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val oldGBT = new OldGBT(oldBoostingStrategy) + val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index db68606397..84148a8a4a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setLossType(loss) .setMaxIter(maxIter) .setStepSize(learningRate) + .setSeed(123) compareAPIs(data, None, gbt, categoricalFeatures) } } @@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxIter(5) .setStepSize(0.1) .setCheckpointInterval(2) + .setSeed(123) val model = gbt.fit(df) sc.checkpointDir = None @@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite { categoricalFeatures: Map[Int, Int]): Unit = { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val oldGBT = new OldGBT(oldBoostingStrategy) + val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 58828b3af9..747c267b4f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext categoricalFeaturesInfo = Map.empty) val boostingStrategy = new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) + val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0) .runWithValidation(trainRdd, validateRdd) val numTrees = gbtValidate.numTrees assert(numTrees !== numIterations) // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd) val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) { val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) |