-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java126
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala64
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala146
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala169
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala78
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala51
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala34
7 files changed, 462 insertions, 206 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
new file mode 100644
index 0000000000..1af2067b2b
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTrees {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ " <Classification/Regression>");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.weakLearnerParams().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+ @Override public Double call(LabeledPoint p) {
+ return p.label();
+ }
+ }).countByValue().size();
+ boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+
+ // Train a GradientBoosting model for classification.
+ final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainErr =
+ 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+ @Override public Boolean call(Tuple2<Double, Double> pl) {
+ return !pl._1().equals(pl._2());
+ }
+ }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+ // Train a GradientBoosting model for regression.
+ final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainMSE =
+ predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+ @Override public Double call(Tuple2<Double, Double> pl) {
+ Double diff = pl._1() - pl._2();
+ return diff * diff;
+ }
+ }).reduce(new Function2<Double, Double, Double>() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 49751a3049..63f02cf7b9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -154,20 +154,30 @@ object DecisionTreeRunner {
}
}
- def run(params: Params) {
-
- val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
- val sc = new SparkContext(conf)
-
- println(s"DecisionTreeRunner with parameters:\n$params")
-
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput is given.
+ * @return (training dataset, test dataset, number of classes),
+ * where the number of classes is inferred from data (and set to 0 for Regression)
+ */
+ private[mllib] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: Algo,
+ fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
// Load training data and cache it.
- val origExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ val origExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
}
// For classification, re-index classes if needed.
- val (examples, classIndexMap, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
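The re-indexing mentioned here maps arbitrary label values onto 0, ..., numClasses-1 before training. A minimal sketch of the idea (illustrative only; the actual code in this file also keeps the resulting index map around so the test set can be remapped consistently):

    import org.apache.spark.mllib.regression.LabeledPoint

    // Map the distinct label values (sorted) to indices 0 ... numClasses-1,
    // then rewrite each example's label through that map.
    val classIndexMap =
      origExamples.map(_.label).countByValue().keys.toList.sorted.zipWithIndex.toMap
    val examples =
      origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))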
@@ -205,14 +215,14 @@ object DecisionTreeRunner {
}
// Create training, test sets.
- val splits = if (params.testInput != "") {
+ val splits = if (testInput != "") {
// Load testInput.
val numFeatures = examples.take(1)(0).features.size
- val origTestExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+ val origTestExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
}
- params.algo match {
+ algo match {
case Classification => {
// classCounts: class --> # examples in class
val testExamples = {
@@ -229,17 +239,31 @@ object DecisionTreeRunner {
}
} else {
// Split input into training, test.
- examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ examples.randomSplit(Array(1.0 - fracTest, fracTest))
}
val training = splits(0).cache()
val test = splits(1).cache()
+
val numTraining = training.count()
val numTest = test.count()
-
println(s"numTraining = $numTraining, numTest = $numTest.")
examples.unpersist(blocking = false)
+ (training, test, numClasses)
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+ params.testInput, params.algo, params.fracTest)
+
val impurityCalculator = params.impurity match {
case Gini => impurity.Gini
case Entropy => impurity.Entropy
@@ -338,7 +362,9 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ private[mllib] def meanSquaredError(
+ tree: WeightedEnsembleModel,
+ data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
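For reference, the extracted meanSquaredError helper computes the usual regression metric, the squared residuals averaged over the dataset:

    \[ \mathrm{MSE} \,=\, \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2, \qquad \hat{y}_i = \mathrm{tree.predict}(x_i) \]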
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
new file mode 100644
index 0000000000..9b6db01448
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTrees {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ numIterations: Int = 10,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GradientBoostedTrees") {
+ head("GradientBoostedTrees: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numIterations")
+ .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"GradientBoostedTrees with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+ val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+ boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.numIterations = params.numIterations
+ boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == "Classification") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $testAccuracy")
+ } else if (params.algo == "Regression") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
+ }
+
+ sc.stop()
+ }
+}
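A note on the metric used above: MulticlassMetrics.precision with no label argument is the micro-averaged precision over all predictions, which, with exactly one prediction per instance, equals overall accuracy; that is why the example reports it as train/test accuracy.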
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
index 1a847201ce..f729344a68 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -17,30 +17,49 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
-
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
-import org.apache.spark.Logging
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.loss.Losses
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
/**
* :: Experimental ::
- * A class that implements gradient boosting for regression and binary classification problems.
+ * A class that implements Stochastic Gradient Boosting
+ * for regression and binary classification problems.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes:
+ * - This currently can be run with several loss functions. However, only SquaredError is
+ * fully supported. Specifically, the loss function should be used to compute the gradient
+ * (to re-label training instances on each iteration) and to weight weak hypotheses.
+ * Currently, gradients are computed correctly for the available loss functions,
+ * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ * Running with those losses will likely behave reasonably, but without the same guarantees.
+ *
* @param boostingStrategy Parameters for the gradient boosting algorithm
*/
@Experimental
class GradientBoosting (
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+ boostingStrategy.weakLearnerParams.algo = Regression
+ boostingStrategy.weakLearnerParams.impurity = impurity.Variance
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ boostingStrategy.weakLearnerParams.numClassesForClassification =
+ boostingStrategy.numClassesForClassification
+
+ boostingStrategy.assertValid()
+
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -51,6 +70,7 @@ class GradientBoosting (
algo match {
case Regression => GradientBoosting.boost(input, boostingStrategy)
case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoosting.boost(remappedInput, boostingStrategy)
case _ =>
@@ -118,120 +138,32 @@ object GradientBoosting extends Logging {
}
/**
- * Method to train a gradient boosting binary classification model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
- * @param loss Loss function used for minimization during gradient boosting.
- * @param learningRate Learning rate for shrinking the contribution of each estimator. The
- * learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
- * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
- * supported.)
- * @return WeightedEnsembleModel that can be used for prediction
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
*/
- def trainClassifier(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: Map[Int, Int],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- val lossType = Losses.fromString(loss)
- val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
- learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
- weakLearnerParams)
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Method to train a gradient boosting regression model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
- * @param loss Loss function used for minimization during gradient boosting.
- * @param learningRate Learning rate for shrinking the contribution of each estimator. The
- * learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
- * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
- * supported.)
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainRegressor(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: Map[Int, Int],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- val lossType = Losses.fromString(loss)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
- learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
- weakLearnerParams)
- new GradientBoosting(boostingStrategy).train(input)
+ def train(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ train(input.rdd, boostingStrategy)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
*/
def trainClassifier(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
- numClassesForClassification,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
- weakLearnerParams)
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainClassifier(input.rdd, boostingStrategy)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
*/
def trainRegressor(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
- numClassesForClassification,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
- weakLearnerParams)
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainRegressor(input.rdd, boostingStrategy)
}
-
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
@@ -247,15 +179,17 @@ object GradientBoosting extends Logging {
timer.start("init")
// Initialize gradient boosting parameters
- val numEstimators = boostingStrategy.numEstimators
- val baseLearners = new Array[DecisionTreeModel](numEstimators)
- val baseLearnerWeights = new Array[Double](numEstimators)
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
val strategy = boostingStrategy.weakLearnerParams
// Cache input
- input.persist(StorageLevel.MEMORY_AND_DISK)
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ }
timer.stop("init")
@@ -264,7 +198,7 @@ object GradientBoosting extends Logging {
logDebug("##########")
var data = input
- // 1. Initialize tree
+ // Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(strategy).train(data)
baseLearners(0) = firstTreeModel
@@ -280,7 +214,7 @@ object GradientBoosting extends Logging {
point.features))
var m = 1
- while (m < numEstimators) {
+ while (m < numIterations) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
@@ -289,6 +223,9 @@ object GradientBoosting extends Logging {
timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
@@ -305,8 +242,6 @@ object GradientBoosting extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
-
- // 3. Output classifier
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
}
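The comment about baseLearnerWeights points at the standard fix: for a general loss, the weight of the m-th weak hypothesis should minimize sum_i L(y_i, F_{m-1}(x_i) + beta * h_m(x_i)) over beta, rather than being set to the constant learning rate. A coarse grid-search sketch, purely illustrative (the function parameters are assumptions, not this patch's internals):

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Pick the weight for the new weak hypothesis by minimizing the total
    // loss on the training data over a small grid of candidate weights.
    // Each candidate triggers one Spark job, so this is only a sketch.
    def lineSearchWeight(
        data: RDD[LabeledPoint],
        ensemblePredict: LabeledPoint => Double,  // F_{m-1}
        weakPredict: LabeledPoint => Double,      // h_m
        loss: (Double, Double) => Double): Double = {
      val candidates = (1 to 20).map(_ * 0.05)    // 0.05, 0.10, ..., 1.00
      candidates.minBy { beta =>
        data.map(lp => loss(lp.label, ensemblePredict(lp) + beta * weakPredict(lp))).sum()
      }
    }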
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 501d9ff9ea..abbda040bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -21,7 +21,6 @@ import scala.beans.BeanProperty
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
/**
@@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
+ * @param numIterations Number of iterations of boosting. In other words, the number of
+ * weak hypotheses used in the final model.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
+ * This setting overrides any setting in [[weakLearnerParams]].
* Default value is 2 (binary classification).
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
- * number of discrete values they take. For example, an entry (n ->
- * k) implies the feature n is categorical with k categories 0,
- * 1, 2, ... , k-1. It's important to note that features are
- * zero-indexed.
* @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
* supported.
*/
@Experimental
case class BoostingStrategy(
// Required boosting parameters
- algo: Algo,
- @BeanProperty var numEstimators: Int,
+ @BeanProperty var algo: Algo,
+ @BeanProperty var numIterations: Int,
@BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var learningRate: Double = 0.1,
- @BeanProperty var subsamplingRate: Double = 1.0,
@BeanProperty var numClassesForClassification: Int = 2,
- @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@BeanProperty var weakLearnerParams: Strategy) extends Serializable {
- require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
- s"$learningRate.")
- require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
- s"$learningRate.")
-
// Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
weakLearnerParams.numClassesForClassification = numClassesForClassification
- weakLearnerParams.subsamplingRate = subsamplingRate
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Check validity of parameters.
+ * Throws exception if invalid.
+ */
+ private[tree] def assertValid(): Unit = {
+ algo match {
+ case Classification =>
+ require(numClassesForClassification == 2,
+ s"Only binary classification is supported for boosting; numClassesForClassification was $numClassesForClassification.")
+ case Regression =>
+ // nothing
+ case _ =>
+ throw new IllegalArgumentException(
+ s"BoostingStrategy given invalid algo parameter: $algo." +
+ s" Valid settings are: Classification, Regression.")
+ }
+ require(learningRate > 0 && learningRate <= 1,
+ "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
+ }
}
@Experimental
@@ -82,28 +93,17 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStrategy = defaultWeakLearnerParams(algo)
+ def defaultParams(algo: String): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy("Regression")
+ treeStrategy.maxDepth = 3
algo match {
- case Classification =>
- new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
- case Regression =>
- new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+ case "Classification" =>
+ new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+ case "Regression" =>
+ new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
+ weakLearnerParams = treeStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}
-
- /**
- * Returns default configuration for the weak learner (decision tree) algorithm
- * @param algo Learning goal. Supported:
- * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @return Configuration for weak learner
- */
- def defaultWeakLearnerParams(algo: Algo): Strategy = {
- // Note: Regression tree used even for classification for GBT.
- new Strategy(Regression, Variance, 3)
- }
-
}
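Taken together, the string-based defaultParams and the new assertValid give a compact configuration flow. A minimal usage sketch, mirroring the examples above (parameter values are illustrative):

    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    // Defaults per this patch: 100 iterations, LogLoss (Classification) or
    // SquaredError (Regression), and a depth-3 regression tree weak learner.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 50
    boostingStrategy.learningRate = 0.2
    boostingStrategy.weakLearnerParams.maxDepth = 5

    // Validation happens when GradientBoosting is constructed: learningRate
    // must be in (0, 1] and Classification currently requires exactly 2
    // classes, so e.g. numClassesForClassification = 3 would fail assertValid().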
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d09295c507..b5b1f82177 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
*/
@Experimental
class Strategy (
- val algo: Algo,
+ @BeanProperty var algo: Algo,
@BeanProperty var impurity: Impurity,
@BeanProperty var maxDepth: Int,
@BeanProperty var numClassesForClassification: Int = 2,
@@ -85,17 +85,9 @@ class Strategy (
@BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- if (algo == Classification) {
- require(numClassesForClassification >= 2)
- }
- require(minInstancesPerNode >= 1,
- s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
- require(maxMemoryInMB <= 10240,
- s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
-
- val isMulticlassClassification =
+ def isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
- val isMulticlassWithCategoricalFeatures
+ def isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
/**
@@ -113,6 +105,23 @@ class Strategy (
}
/**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Sets categoricalFeaturesInfo using a Java Map.
+ */
+ def setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
+ setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ }
+
+ /**
* Check validity of parameters.
* Throws exception if invalid.
*/
@@ -143,6 +152,26 @@ class Strategy (
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
s" feature $feature has $arity categories. The number of categories should be >= 2.")
}
+ require(minInstancesPerNode >= 1,
+ s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+}
+
+@Experimental
+object Strategy {
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo "Classification" or "Regression"
+ */
+ def defaultStrategy(algo: String): Strategy = algo match {
+ case "Classification" =>
+ new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
+ numClassesForClassification = 2)
+ case "Regression" =>
+ new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
+ numClassesForClassification = 0)
+ }
}
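The new companion object plus the Java-friendly setters let Strategy be configured without Scala-specific types. A short sketch (the feature index and arity are made-up examples):

    import org.apache.spark.mllib.tree.configuration.Strategy

    // Defaults per this patch: Gini impurity and maxDepth = 10 for
    // "Classification", Variance impurity for "Regression".
    val treeStrategy = Strategy.defaultStrategy("Classification")
    treeStrategy.setMaxDepth(8)  // bean-style setter generated by @BeanProperty

    // Java callers can now pass a java.util.Map directly: feature 0 is
    // declared categorical with 4 values {0, 1, 2, 3}.
    val info = new java.util.HashMap[java.lang.Integer, java.lang.Integer]()
    info.put(0, 4)
    treeStrategy.setCategoricalFeaturesInfo(info)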
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
index 970fff8221..99a02eda60 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -22,9 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
-import org.apache.spark.mllib.tree.impurity.{Variance, Gini}
+import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.LocalSparkContext
@@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext
class GradientBoostingSuite extends FunSuite with LocalSparkContext {
test("Regression with continuous features: SquaredError") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, 1, treeStrategy)
val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
}
test("Regression with continuous features: Absolute Error") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, numClassesForClassification = 2, treeStrategy)
val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
}
}
-
test("Binary classification with continuous features: Log Loss") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
+ learningRate, numClassesForClassification = 2, treeStrategy)
val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
@@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
object GradientBoostingSuite {
// Combinations of (numIterations, learningRate, subsamplingRate) used by the tests
- val testCombinations
- = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+ val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
}