diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2014-11-05 10:33:13 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-05 10:33:13 -0800 |
commit | 5b3b6f6f5f029164d7749366506e142b104c1d43 (patch) | |
tree | 78dc97e3f62d803e0f5cd3837d6a454f27e6e155 /mllib/src/main | |
parent | 5f13759d3642ea5b58c12a756e7125ac19aff10e (diff) | |
download | spark-5b3b6f6f5f029164d7749366506e142b104c1d43.tar.gz spark-5b3b6f6f5f029164d7749366506e142b104c1d43.tar.bz2 spark-5b3b6f6f5f029164d7749366506e142b104c1d43.zip |
[SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java
### Summary
* Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types.
* Added Scala and Java examples for GradientBoostedTrees
* small cleanups and fixes
### Details
GradientBoosting bug fixes (“bug” = bad default options)
* Force boostingStrategy.weakLearnerParams.algo = Regression
* Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance
* Only persist data if not yet persisted (since it causes an error if persisted twice)
BoostingStrategy
* numEstimators: renamed to numIterations
* removed subsamplingRate (duplicated by Strategy)
* removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type)
* Changed algo to var (not val) and added BeanProperty, with overload taking String argument
* Added assertValid() method
* Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy
Strategy (for DecisionTree)
* Changed algo to var (not val) and added BeanProperty, with overload taking String argument
* Added setCategoricalFeaturesInfo method taking Java Map.
* Cleaned up assertValid
* Changed val’s to def’s since parameters can now be changed.
CC: manishamde mengxr codedeft
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #3094 from jkbradley/gbt-api and squashes the following commits:
7a27e22 [Joseph K. Bradley] scalastyle fix
52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api
e9b8410 [Joseph K. Bradley] Summary of changes
Diffstat (limited to 'mllib/src/main')
3 files changed, 131 insertions, 167 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala index 1a847201ce..f729344a68 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -17,30 +17,49 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ - +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} -import org.apache.spark.Logging -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: - * A class that implements gradient boosting for regression and binary classification problems. + * A class that implements Stochastic Gradient Boosting + * for regression and binary classification problems. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes: + * - This currently can be run with several loss functions. However, only SquaredError is + * fully supported. Specifically, the loss function should be used to compute the gradient + * (to re-label training instances on each iteration) and to weight weak hypotheses. + * Currently, gradients are computed correctly for the available loss functions, + * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. + * Running with those losses will likely behave reasonably, but lacks the same guarantees. + * * @param boostingStrategy Parameters for the gradient boosting algorithm */ @Experimental class GradientBoosting ( private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + boostingStrategy.weakLearnerParams.algo = Regression + boostingStrategy.weakLearnerParams.impurity = impurity.Variance + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + boostingStrategy.weakLearnerParams.numClassesForClassification = + boostingStrategy.numClassesForClassification + + boostingStrategy.assertValid() + /** * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. @@ -51,6 +70,7 @@ class GradientBoosting ( algo match { case Regression => GradientBoosting.boost(input, boostingStrategy) case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoosting.boost(remappedInput, boostingStrategy) case _ => @@ -118,120 +138,32 @@ object GradientBoosting extends Logging { } /** - * Method to train a gradient boosting binary classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] */ - def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. - * @param loss Loss function used for minimization during gradient boosting. - * @param learningRate Learning rate for shrinking the contribution of each estimator. The - * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is - * supported.) - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: Map[Int, Int], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - val lossType = Losses.fromString(loss) - val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, - learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, - weakLearnerParams) - new GradientBoosting(boostingStrategy).train(input) + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + train(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] */ def trainClassifier( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainClassifier(input.rdd, boostingStrategy) } /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] */ def trainRegressor( - input: RDD[LabeledPoint], - numEstimators: Int, - loss: String, - learningRate: Double, - subsamplingRate: Double, - numClassesForClassification: Int, - categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], - weakLearnerParams: Strategy): WeightedEnsembleModel = { - trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, - numClassesForClassification, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - weakLearnerParams) + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + trainRegressor(input.rdd, boostingStrategy) } - /** * Internal method for performing regression using trees as base learners. * @param input training dataset @@ -247,15 +179,17 @@ object GradientBoosting extends Logging { timer.start("init") // Initialize gradient boosting parameters - val numEstimators = boostingStrategy.numEstimators - val baseLearners = new Array[DecisionTreeModel](numEstimators) - val baseLearnerWeights = new Array[Double](numEstimators) + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate val strategy = boostingStrategy.weakLearnerParams // Cache input - input.persist(StorageLevel.MEMORY_AND_DISK) + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } timer.stop("init") @@ -264,7 +198,7 @@ object GradientBoosting extends Logging { logDebug("##########") var data = input - // 1. Initialize tree + // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(strategy).train(data) baseLearners(0) = firstTreeModel @@ -280,7 +214,7 @@ object GradientBoosting extends Logging { point.features)) var m = 1 - while (m < numEstimators) { + while (m < numIterations) { timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) @@ -289,6 +223,9 @@ object GradientBoosting extends Logging { timer.stop(s"building tree $m") // Create partial model baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), @@ -305,8 +242,6 @@ object GradientBoosting extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - // 3. Output classifier new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 501d9ff9ea..abbda040bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,6 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** @@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @param numEstimators Number of estimators used in boosting stages. In other words, - * number of boosting iterations performed. + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. * @param loss Loss function used for minimization during gradient boosting. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param subsamplingRate Fraction of the training data used for learning the decision tree. * @param numClassesForClassification Number of classes for classification. * (Ignored for regression.) + * This setting overrides any setting in [[weakLearnerParams]]. * Default value is 2 (binary classification). - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are * supported. */ @Experimental case class BoostingStrategy( // Required boosting parameters - algo: Algo, - @BeanProperty var numEstimators: Int, + @BeanProperty var algo: Algo, + @BeanProperty var numIterations: Int, @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var subsamplingRate: Double = 1.0, @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " + - s"$learningRate.") - require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " + - s"$learningRate.") - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo weakLearnerParams.numClassesForClassification = numClassesForClassification - weakLearnerParams.subsamplingRate = subsamplingRate + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification == 2) + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") + } } @Experimental @@ -82,28 +93,17 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: Algo): BoostingStrategy = { - val treeStrategy = defaultWeakLearnerParams(algo) + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy("Regression") + treeStrategy.maxDepth = 3 algo match { - case Classification => - new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy) - case Regression => - new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy) + case "Classification" => + new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + case "Regression" => + new BoostingStrategy(Algo.withName(algo), 100, SquaredError, + weakLearnerParams = treeStrategy) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } } - - /** - * Returns default configuration for the weak learner (decision tree) algorithm - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] - * @return Configuration for weak learner - */ - def defaultWeakLearnerParams(algo: Algo): Strategy = { - // Note: Regression tree used even for classification for GBT. - new Strategy(Regression, Variance, 3) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d09295c507..b5b1f82177 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Experimental class Strategy ( - val algo: Algo, + @BeanProperty var algo: Algo, @BeanProperty var impurity: Impurity, @BeanProperty var maxDepth: Int, @BeanProperty var numClassesForClassification: Int = 2, @@ -85,17 +85,9 @@ class Strategy ( @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -113,6 +105,23 @@ class Strategy ( } /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. + */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + + /** * Check validity of parameters. * Throws exception if invalid. */ @@ -143,6 +152,26 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } +} + +@Experimental +object Strategy { + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } |