aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-23 15:08:47 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-23 15:08:47 -0700
commit69bc2c17f1ca047d4915a4791b624d60c5943dc8 (patch)
treef59f8e0aae9421eab43bd2509b32e222ea104fa4 /mllib
parent5dfc01976bb0d72489620b4f32cc12d620bb6260 (diff)
downloadspark-69bc2c17f1ca047d4915a4791b624d60c5943dc8.tar.gz
spark-69bc2c17f1ca047d4915a4791b624d60c5943dc8.tar.bz2
spark-69bc2c17f1ca047d4915a4791b624d60c5943dc8.zip
[SPARK-13952][ML] Add random seed to GBT
## What changes were proposed in this pull request? `GBTClassifier` and `GBTRegressor` should use random seed for reproducible results. Because of the nature of current unit tests, which compare GBTs in ML and GBTs in MLlib for equality, I also added a random seed to MLlib GBT algorithm. I made alternate constructors in `mllib.tree.GradientBoostedTrees` to accept a random seed, but left them as private so as to not change the API unnecessarily. ## How was this patch tested? Existing unit tests verify that functionality did not change. Other ML algorithms do not seem to have unit tests that directly test the functionality of random seeding, but reproducibility with seeding for GBTs is effectively verified in existing tests. I can add more tests if needed. Author: sethah <seth.hendrickson16@gmail.com> Closes #11903 from sethah/SPARK-13952.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala30
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala30
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala4
9 files changed, 66 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 5a8845fdb6..c31df3aa18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") (
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
@Since("1.4.0")
- override def setSeed(value: Long): this.type = {
- logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
- super.setSeed(value)
- }
+ override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from GBTParams:
@@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+ val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
+ $(seed))
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 428bc7a6d8..fa7cc436f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 0L, parentUID = Some(uid))
+ seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 091e1d5fa8..da5b77e8fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
@Since("1.4.0")
- override def setSeed(value: Long): this.type = {
- logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
- super.setSeed(value)
- }
+ override def setSeed(value: Long): this.type = super.setSeed(value)
// Parameters from GBTParams:
@Since("1.4.0")
@@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+ val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
+ $(seed))
new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index b9acc66472..1c8a9b4dfe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
- def run(input: RDD[LabeledPoint],
- boostingStrategy: OldBoostingStrategy
- ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ def run(
+ input: RDD[LabeledPoint],
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
+ seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
}
@@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging {
* but it should follow the same distribution.
* E.g., these two datasets could be created from an original dataset
* by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
def runWithValidation(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
- boostingStrategy: OldBoostingStrategy
- ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ boostingStrategy: OldBoostingStrategy,
+ seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
case OldAlgo.Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging {
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true)
+ validate = true, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging {
* @param validationInput validation dataset, ignored if validate is set to false.
* @param boostingStrategy boosting parameters
* @param validate whether or not to use the validation dataset.
+ * @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
@@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging {
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: OldBoostingStrategy,
- validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+ validate: Boolean,
+ seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
- val firstTree = new DecisionTreeRegressor()
+ val firstTree = new DecisionTreeRegressor().setSeed(seed)
val firstTreeModel = firstTree.train(input, treeStrategy)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
@@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
- val dt = new DecisionTreeRegressor()
+ val dt = new DecisionTreeRegressor().setSeed(seed + m)
val model = dt.train(data, treeStrategy)
timer.stop(s"building tree $m")
// Update partial model
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 8f02e098ac..c40d5e3fff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of decision tree (classification or regression), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
+ * @param seed Random seed.
*/
@Since("1.0.0")
-class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
+class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
extends Serializable with Logging {
+ /**
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of decision tree (classification or regression), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ */
+ @Since("1.0.0")
+ def this(strategy: Strategy) = this(strategy, seed = 0)
+
strategy.assertValid()
/**
@@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
- // Note: random seed will not be used since numTrees = 1.
- val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
+ val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = seed)
val rfModel = rf.run(input)
rfModel.trees(0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index eb40fb0391..d166dc7905 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -47,12 +47,21 @@ import org.apache.spark.storage.StorageLevel
* for other loss functions.
*
* @param boostingStrategy Parameters for the gradient boosting algorithm.
+ * @param seed Random seed.
*/
@Since("1.2.0")
-class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
+class GradientBoostedTrees private[spark] (
+ private val boostingStrategy: BoostingStrategy,
+ private val seed: Int)
extends Serializable with Logging {
/**
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+ @Since("1.2.0")
+ def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0)
+
+ /**
* Method to train a gradient boosting model
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
+ seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val algo = boostingStrategy.treeStrategy.algo
algo match {
case Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true)
+ validate = true, seed)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging {
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
- new GradientBoostedTrees(boostingStrategy).run(input)
+ new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
}
/**
@@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging {
* @param validationInput Validation dataset, ignored if validate is set to false.
* @param boostingStrategy Boosting parameters.
* @param validate Whether or not to use the validation dataset.
+ * @param seed Random seed.
* @return GradientBoostedTreesModel that can be used for prediction.
*/
private def boost(
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy,
- validate: Boolean): GradientBoostedTreesModel = {
+ validate: Boolean,
+ seed: Int): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
@@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging {
// Initialize tree
timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy).run(input)
+ val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
@@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
- val model = new DecisionTree(treeStrategy).run(data)
+ val model = new DecisionTree(treeStrategy, seed + m).run(data)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 29efd675ab..f3680ed044 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setLossType("logistic")
.setMaxIter(maxIter)
.setStepSize(learningRate)
+ .setSeed(123)
compareAPIs(data, None, gbt, categoricalFeatures)
}
}
@@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxIter(5)
.setStepSize(0.1)
.setCheckpointInterval(2)
+ .setSeed(123)
val model = gbt.fit(df)
// copied model must have the same parent.
@@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldBoostingStrategy =
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
- val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index db68606397..84148a8a4a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setLossType(loss)
.setMaxIter(maxIter)
.setStepSize(learningRate)
+ .setSeed(123)
compareAPIs(data, None, gbt, categoricalFeatures)
}
}
@@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxIter(5)
.setStepSize(0.1)
.setCheckpointInterval(2)
+ .setSeed(123)
val model = gbt.fit(df)
sc.checkpointDir = None
@@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite {
categoricalFeatures: Map[Int, Int]): Unit = {
val numFeatures = data.first().features.size
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 58828b3af9..747c267b4f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
categoricalFeaturesInfo = Map.empty)
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
.runWithValidation(trainRdd, validateRdd)
val numTrees = gbtValidate.numTrees
assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+ val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))