-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala        8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala    2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala             8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala     30
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala               15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala       30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala   4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala        4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala   4
9 files changed, 66 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 5a8845fdb6..c31df3aa18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") (
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
@Since("1.4.0")
- override def setSeed(value: Long): this.type = {
- logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
- super.setSeed(value)
- }
+ override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from GBTParams:
@@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+ val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
+ $(seed))
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
}
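
With GBTClassifier now forwarding its seed Param into GradientBoostedTrees.run, two fits with the same seed should produce identical ensembles. A minimal reproducibility sketch, assuming a hypothetical DataFrame `train` with "label" and "features" columns prepared elsewhere (not part of this patch):

import org.apache.spark.ml.classification.GBTClassifier

val gbt = new GBTClassifier()
  .setMaxIter(10)
  .setSeed(42L)          // after this patch, the seed reaches GradientBoostedTrees.run
val modelA = gbt.fit(train)
val modelB = gbt.fit(train)
// With a fixed seed the two ensembles should carry identical tree weights.
assert(modelA.treeWeights.sameElements(modelB.treeWeights))
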
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 428bc7a6d8..fa7cc436f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 0L, parentUID = Some(uid))
+ seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
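
For context, `$(seed)` is the ml Params shorthand for reading the estimator's `seed` Param, so the single-tree path now respects whatever the caller set instead of the hard-coded 0L. A rough illustration with a hypothetical value:

import org.apache.spark.ml.regression.DecisionTreeRegressor

val dtr = new DecisionTreeRegressor().setSeed(7L)
// $(seed) inside train(...) now resolves to 7L; leaving the seed unset
// falls back to the Param's default value rather than a fixed 0L.
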
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 091e1d5fa8..da5b77e8fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
@Since("1.4.0")
- override def setSeed(value: Long): this.type = {
- logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
- super.setSeed(value)
- }
+ override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from GBTParams:
@Since("1.4.0")
@@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+ val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
+ $(seed))
new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index b9acc66472..1c8a9b4dfe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
- def run(input: RDD[LabeledPoint],
- boostingStrategy: OldBoostingStrategy
- ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ def run(
+ input: RDD[LabeledPoint],
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
+ seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
@@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging {
* but it should follow the same distribution.
* E.g., these two datasets could be created from an original dataset
* by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
def runWithValidation(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
- boostingStrategy: OldBoostingStrategy
- ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging {
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true)
+ validate = true, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging {
* @param validationInput validation dataset, ignored if validate is set to false.
* @param boostingStrategy boosting parameters
* @param validate whether or not to use the validation dataset.
+ * @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
@@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging {
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
- validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ validate: Boolean,
+ seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
- val firstTree = new DecisionTreeRegressor()
+ val firstTree = new DecisionTreeRegressor().setSeed(seed)
val firstTreeModel = firstTree.train(input, treeStrategy)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
@@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
- val dt = new DecisionTreeRegressor()
+ val dt = new DecisionTreeRegressor().setSeed(seed + m)
val model = dt.train(data, treeStrategy)
timer.stop(s"building tree $m")
// Update partial model
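
Note that later iterations use `seed + m` rather than the same seed each round, so every boosting step draws from a different but reproducible random stream. A simplified sketch of that seeding scheme (illustration only, not Spark code):

def iterationSeed(baseSeed: Long, iteration: Int): Long = baseSeed + iteration

val seeds = (0 until 5).map(m => iterationSeed(42L, m))
// yields 42L through 46L: distinct per iteration, identical across runs.
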
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 8f02e098ac..c40d5e3fff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
+ * @param seed Random seed.
*/
@Since("1.0.0")
-class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
+class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
extends Serializable with Logging {
+ /**
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of decision tree (classification or regression), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ */
+ @Since("1.0.0")
+ def this(strategy: Strategy) = this(strategy, seed = 0)
+
strategy.assertValid()
/**
@@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
- // Note: random seed will not be used since numTrees = 1.
- val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
+ val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = seed)
val rfModel = rf.run(input)
rfModel.trees(0)
}
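
The old public constructor survives as an auxiliary constructor, so user code that only passes a Strategy keeps compiling, while the seeded primary constructor is reachable only from inside Spark. A hedged sketch of both call sites (the Strategy construction is abbreviated and hypothetical):

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance

// Hypothetical strategy; real code would set more parameters.
val strategy = new Strategy(Algo.Regression, Variance, maxDepth = 5)

val dt = new DecisionTree(strategy)        // still compiles, delegates with seed = 0
// new DecisionTree(strategy, seed = 17)   // only legal inside org.apache.spark packages
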
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index eb40fb0391..d166dc7905 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -47,12 +47,21 @@ import org.apache.spark.storage.StorageLevel
* for other loss functions.
*
* @param boostingStrategy Parameters for the gradient boosting algorithm.
+ * @param seed Random seed.
*/
@Since("1.2.0")
-class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
+class GradientBoostedTrees private[spark] (
+ private val boostingStrategy: BoostingStrategy,
+ private val seed: Int)
extends Serializable with Logging {
/**
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+ @Since("1.2.0")
+ def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0)
+
+ /**
* Method to train a gradient boosting model
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
+ seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true)
+ validate = true, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging {
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
- new GradientBoostedTrees(boostingStrategy).run(input)
+ new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
}
/**
@@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging {
* @param validationInput Validation dataset, ignored if validate is set to false.
* @param boostingStrategy Boosting parameters.
* @param validate Whether or not to use the validation dataset.
+ * @param seed Random seed.
* @return GradientBoostedTreesModel that can be used for prediction.
*/
private def boost(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy,
- validate: Boolean): GradientBoostedTreesModel = {
+ validate: Boolean,
+ seed: Int): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy).run(input)
+ val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
@@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
- val model = new DecisionTree(treeStrategy).run(data)
+ val model = new DecisionTree(treeStrategy, seed + m).run(data)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
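
As with DecisionTree, the seeded primary constructor is private[spark]; the public one-argument constructor and the train helper continue to pass seed = 0, so results from the existing mllib API should be unchanged by this patch. A rough sketch of the unchanged public path (the strategy setup is abbreviated):

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

// Hypothetical defaults; real code would tune the strategy.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
val gbt = new GradientBoostedTrees(boostingStrategy)   // same as before: delegates with seed = 0
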
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 29efd675ab..f3680ed044 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setLossType("logistic")
.setMaxIter(maxIter)
.setStepSize(learningRate)
+ .setSeed(123)
compareAPIs(data, None, gbt, categoricalFeatures)
}
}
@@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxIter(5)
.setStepSize(0.1)
.setCheckpointInterval(2)
+ .setSeed(123)
val model = gbt.fit(df)
// copied model must have the same parent.
@@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldBoostingStrategy =
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
- val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
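
The `.toInt` above is needed because the ml `seed` Param is a Long while the mllib constructor added in this patch takes an Int; for a small test seed the narrowing is lossless. A trivial sketch of the conversion:

val mlSeed: Long = 123L
val mllibSeed: Int = mlSeed.toInt   // safe for small test seeds such as 123
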
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index db68606397..84148a8a4a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setLossType(loss)
.setMaxIter(maxIter)
.setStepSize(learningRate)
+ .setSeed(123)
compareAPIs(data, None, gbt, categoricalFeatures)
}
}
@@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxIter(5)
.setStepSize(0.1)
.setCheckpointInterval(2)
+ .setSeed(123)
val model = gbt.fit(df)
sc.checkpointDir = None
@@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite {
categoricalFeatures: Map[Int, Int]): Unit = {
val numFeatures = data.first().features.size
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 58828b3af9..747c267b4f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
categoricalFeaturesInfo = Map.empty)
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
.runWithValidation(trainRdd, validateRdd)
val numTrees = gbtValidate.numTrees
assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+ val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))