aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2014-11-05 10:33:13 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-05 10:33:13 -0800
commit5b3b6f6f5f029164d7749366506e142b104c1d43 (patch)
tree78dc97e3f62d803e0f5cd3837d6a454f27e6e155 /mllib
parent5f13759d3642ea5b58c12a756e7125ac19aff10e (diff)
downloadspark-5b3b6f6f5f029164d7749366506e142b104c1d43.tar.gz
spark-5b3b6f6f5f029164d7749366506e142b104c1d43.tar.bz2
spark-5b3b6f6f5f029164d7749366506e142b104c1d43.zip
[SPARK-4197] [mllib] GradientBoosting API cleanup and examples in Scala, Java
### Summary * Made it easier to construct default Strategy and BoostingStrategy and to set parameters using simple types. * Added Scala and Java examples for GradientBoostedTrees * small cleanups and fixes ### Details GradientBoosting bug fixes (“bug” = bad default options) * Force boostingStrategy.weakLearnerParams.algo = Regression * Force boostingStrategy.weakLearnerParams.impurity = impurity.Variance * Only persist data if not yet persisted (since it causes an error if persisted twice) BoostingStrategy * numEstimators: renamed to numIterations * removed subsamplingRate (duplicated by Strategy) * removed categoricalFeaturesInfo since it belongs with the weak learner params (since boosting can be oblivious to feature type) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added assertValid() method * Updated defaultParams() method and eliminated defaultWeakLearnerParams() since that belongs in Strategy Strategy (for DecisionTree) * Changed algo to var (not val) and added BeanProperty, with overload taking String argument * Added setCategoricalFeaturesInfo method taking Java Map. * Cleaned up assertValid * Changed val’s to def’s since parameters can now be changed. CC: manishamde mengxr codedeft Author: Joseph K. Bradley <joseph@databricks.com> Closes #3094 from jkbradley/gbt-api and squashes the following commits: 7a27e22 [Joseph K. Bradley] scalastyle fix 52013d5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into gbt-api e9b8410 [Joseph K. Bradley] Summary of changes
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala169
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala78
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala51
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala34
4 files changed, 145 insertions, 187 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
index 1a847201ce..f729344a68 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -17,30 +17,49 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
-
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
-import org.apache.spark.Logging
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.loss.Losses
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
/**
* :: Experimental ::
- * A class that implements gradient boosting for regression and binary classification problems.
+ * A class that implements Stochastic Gradient Boosting
+ * for regression and binary classification problems.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes:
+ * - This currently can be run with several loss functions. However, only SquaredError is
+ * fully supported. Specifically, the loss function should be used to compute the gradient
+ * (to re-label training instances on each iteration) and to weight weak hypotheses.
+ * Currently, gradients are computed correctly for the available loss functions,
+ * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ * Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
* @param boostingStrategy Parameters for the gradient boosting algorithm
*/
@Experimental
class GradientBoosting (
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+ boostingStrategy.weakLearnerParams.algo = Regression
+ boostingStrategy.weakLearnerParams.impurity = impurity.Variance
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ boostingStrategy.weakLearnerParams.numClassesForClassification =
+ boostingStrategy.numClassesForClassification
+
+ boostingStrategy.assertValid()
+
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -51,6 +70,7 @@ class GradientBoosting (
algo match {
case Regression => GradientBoosting.boost(input, boostingStrategy)
case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoosting.boost(remappedInput, boostingStrategy)
case _ =>
@@ -118,120 +138,32 @@ object GradientBoosting extends Logging {
}
/**
- * Method to train a gradient boosting binary classification model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
- * @param loss Loss function used for minimization during gradient boosting.
- * @param learningRate Learning rate for shrinking the contribution of each estimator. The
- * learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
- * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
- * supported.)
- * @return WeightedEnsembleModel that can be used for prediction
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
*/
- def trainClassifier(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: Map[Int, Int],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- val lossType = Losses.fromString(loss)
- val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
- learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
- weakLearnerParams)
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Method to train a gradient boosting regression model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
- * @param loss Loss function used for minimization during gradient boosting.
- * @param learningRate Learning rate for shrinking the contribution of each estimator. The
- * learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
- * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
- * supported.)
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainRegressor(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: Map[Int, Int],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- val lossType = Losses.fromString(loss)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
- learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
- weakLearnerParams)
- new GradientBoosting(boostingStrategy).train(input)
+ def train(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ train(input.rdd, boostingStrategy)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
*/
def trainClassifier(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
- numClassesForClassification,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
- weakLearnerParams)
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainClassifier(input.rdd, boostingStrategy)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
*/
def trainRegressor(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
- numClassesForClassification,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
- weakLearnerParams)
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainRegressor(input.rdd, boostingStrategy)
}
-
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
@@ -247,15 +179,17 @@ object GradientBoosting extends Logging {
timer.start("init")
// Initialize gradient boosting parameters
- val numEstimators = boostingStrategy.numEstimators
- val baseLearners = new Array[DecisionTreeModel](numEstimators)
- val baseLearnerWeights = new Array[Double](numEstimators)
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
val strategy = boostingStrategy.weakLearnerParams
// Cache input
- input.persist(StorageLevel.MEMORY_AND_DISK)
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ }
timer.stop("init")
@@ -264,7 +198,7 @@ object GradientBoosting extends Logging {
logDebug("##########")
var data = input
- // 1. Initialize tree
+ // Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(strategy).train(data)
baseLearners(0) = firstTreeModel
@@ -280,7 +214,7 @@ object GradientBoosting extends Logging {
point.features))
var m = 1
- while (m < numEstimators) {
+ while (m < numIterations) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
@@ -289,6 +223,9 @@ object GradientBoosting extends Logging {
timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
@@ -305,8 +242,6 @@ object GradientBoosting extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
-
- // 3. Output classifier
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 501d9ff9ea..abbda040bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -21,7 +21,6 @@ import scala.beans.BeanProperty
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
/**
@@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
+ * @param numIterations Number of iterations of boosting. In other words, the number of
+ * weak hypotheses used in the final model.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
+ * This setting overrides any setting in [[weakLearnerParams]].
* Default value is 2 (binary classification).
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
- * number of discrete values they take. For example, an entry (n ->
- * k) implies the feature n is categorical with k categories 0,
- * 1, 2, ... , k-1. It's important to note that features are
- * zero-indexed.
* @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
* supported.
*/
@Experimental
case class BoostingStrategy(
// Required boosting parameters
- algo: Algo,
- @BeanProperty var numEstimators: Int,
+ @BeanProperty var algo: Algo,
+ @BeanProperty var numIterations: Int,
@BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var learningRate: Double = 0.1,
- @BeanProperty var subsamplingRate: Double = 1.0,
@BeanProperty var numClassesForClassification: Int = 2,
- @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@BeanProperty var weakLearnerParams: Strategy) extends Serializable {
- require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
- s"$learningRate.")
- require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
- s"$learningRate.")
-
// Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
weakLearnerParams.numClassesForClassification = numClassesForClassification
- weakLearnerParams.subsamplingRate = subsamplingRate
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Check validity of parameters.
+ * Throws exception if invalid.
+ */
+ private[tree] def assertValid(): Unit = {
+ algo match {
+ case Classification =>
+ require(numClassesForClassification == 2)
+ case Regression =>
+ // nothing
+ case _ =>
+ throw new IllegalArgumentException(
+ s"BoostingStrategy given invalid algo parameter: $algo." +
+ s" Valid settings are: Classification, Regression.")
+ }
+ require(learningRate > 0 && learningRate <= 1,
+ "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
+ }
}
@Experimental
@@ -82,28 +93,17 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStrategy = defaultWeakLearnerParams(algo)
+ def defaultParams(algo: String): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy("Regression")
+ treeStrategy.maxDepth = 3
algo match {
- case Classification =>
- new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
- case Regression =>
- new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+ case "Classification" =>
+ new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+ case "Regression" =>
+ new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
+ weakLearnerParams = treeStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
}
}
-
- /**
- * Returns default configuration for the weak learner (decision tree) algorithm
- * @param algo Learning goal. Supported:
- * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @return Configuration for weak learner
- */
- def defaultWeakLearnerParams(algo: Algo): Strategy = {
- // Note: Regression tree used even for classification for GBT.
- new Strategy(Regression, Variance, 3)
- }
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d09295c507..b5b1f82177 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
*/
@Experimental
class Strategy (
- val algo: Algo,
+ @BeanProperty var algo: Algo,
@BeanProperty var impurity: Impurity,
@BeanProperty var maxDepth: Int,
@BeanProperty var numClassesForClassification: Int = 2,
@@ -85,17 +85,9 @@ class Strategy (
@BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- if (algo == Classification) {
- require(numClassesForClassification >= 2)
- }
- require(minInstancesPerNode >= 1,
- s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
- require(maxMemoryInMB <= 10240,
- s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
-
- val isMulticlassClassification =
+ def isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
- val isMulticlassWithCategoricalFeatures
+ def isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
/**
@@ -113,6 +105,23 @@ class Strategy (
}
/**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Sets categoricalFeaturesInfo using a Java Map.
+ */
+ def setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
+ setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ }
+
+ /**
* Check validity of parameters.
* Throws exception if invalid.
*/
@@ -143,6 +152,26 @@ class Strategy (
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
s" feature $feature has $arity categories. The number of categories should be >= 2.")
}
+ require(minInstancesPerNode >= 1,
+ s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+}
+
+@Experimental
+object Strategy {
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo "Classification" or "Regression"
+ */
+ def defaultStrategy(algo: String): Strategy = algo match {
+ case "Classification" =>
+ new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
+ numClassesForClassification = 2)
+ case "Regression" =>
+ new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
+ numClassesForClassification = 0)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
index 970fff8221..99a02eda60 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -22,9 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
-import org.apache.spark.mllib.tree.impurity.{Variance, Gini}
+import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.LocalSparkContext
@@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext
class GradientBoostingSuite extends FunSuite with LocalSparkContext {
test("Regression with continuous features: SquaredError") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, 1, treeStrategy)
val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
}
test("Regression with continuous features: Absolute Error") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, numClassesForClassification = 2, treeStrategy)
val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
}
}
-
test("Binary classification with continuous features: Log Loss") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
+ learningRate, numClassesForClassification = 2, treeStrategy)
val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
@@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
object GradientBoostingSuite {
// Combinations for estimators, learning rates and subsamplingRate
- val testCombinations
- = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+ val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
}