-rw-r--r--  examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java (renamed from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java)  | 18
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala  | 18
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala (renamed from examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala)  | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala  | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala)  | 139
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala  | 40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala  | 50
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala  | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala  | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala  | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala  | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala  | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala  | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala  | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala  | 158
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala  | 178
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java  | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala  | 30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala)  | 91
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala  | 14
20 files changed, 382 insertions(+), 437 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
index 1af2067b2b..4a5ac404ea 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;
/**
* Classification and regression using gradient-boosted decision trees.
*/
-public final class JavaGradientBoostedTrees {
+public final class JavaGradientBoostedTreesRunner {
private static void usage() {
- System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
" <Classification/Regression>");
System.exit(-1);
}
@@ -55,7 +55,7 @@ public final class JavaGradientBoostedTrees {
if (args.length > 2) {
usage();
}
- SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
@@ -64,7 +64,7 @@ public final class JavaGradientBoostedTrees {
// Note: All features are treated as continuous.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
boostingStrategy.setNumIterations(10);
- boostingStrategy.weakLearnerParams().setMaxDepth(5);
+ boostingStrategy.treeStrategy().setMaxDepth(5);
if (algo.equals("Classification")) {
// Compute the number of classes from the data.
@@ -73,10 +73,10 @@ public final class JavaGradientBoostedTrees {
return p.label();
}
}).countByValue().size();
- boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+ boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
// Train a GradientBoostedTrees model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public final class JavaGradientBoostedTrees {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoostedTrees model for regression.
- final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
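For orientation, a minimal sketch of the renamed API this example now exercises, written in Scala (an existing SparkContext `sc` is assumed, and the data path is a placeholder):

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.util.MLUtils

    // Per-tree settings now hang off treeStrategy (formerly weakLearnerParams).
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.setNumIterations(10)
    boostingStrategy.treeStrategy.setMaxDepth(5)

    // "data/sample.libsvm" is a hypothetical path.
    val data = MLUtils.loadLibSVMFile(sc, "data/sample.libsvm").cache()
    // train replaces the removed trainClassifier/trainRegressor entry points.
    val model = GradientBoostedTrees.train(data, boostingStrategy)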
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 63f02cf7b9..98f9d1689c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -352,21 +352,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = tree.predict(y.features) - y.label
- err * err
- }.mean()
- }
-
- /**
- * Calculates the mean squared error for regression.
- */
private[mllib] def meanSquaredError(
- tree: WeightedEnsembleModel,
+ model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
data.map { y =>
- val err = tree.predict(y.features) - y.label
+ val err = model.predict(y.features) - y.label
err * err
}.mean()
}
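Because the consolidated helper takes a structural type, one definition now serves every model exposing predict(Vector): Double; a sketch of the call sites (the models and test RDD are assumed trained elsewhere):

    // DecisionTreeModel, RandomForestModel, and GradientBoostedTreesModel all qualify.
    val dtMSE = meanSquaredError(treeModel, testData)
    val gbtMSE = meanSquaredError(gbtModel, testData)

Structural types are dispatched reflectively in Scala, which is an acceptable cost here since the helper lives in example code.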
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 9b6db01448..1def8b45a2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -21,21 +21,21 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
- * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*
* Note: This script treats all features as real-valued (not categorical).
* To include categorical features, modify categoricalFeaturesInfo.
*/
-object GradientBoostedTrees {
+object GradientBoostedTreesRunner {
case class Params(
input: String = null,
@@ -93,24 +93,24 @@ object GradientBoostedTrees {
def run(params: Params) {
- val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
val sc = new SparkContext(conf)
- println(s"GradientBoostedTrees with parameters:\n$params")
+ println(s"GradientBoostedTreesRunner with parameters:\n$params")
// Load training and test data and cache it.
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
- boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.treeStrategy.numClassesForClassification = numClasses
boostingStrategy.numIterations = params.numIterations
- boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+ boostingStrategy.treeStrategy.maxDepth = params.maxDepth
val randomSeed = Utils.random.nextInt()
if (params.algo == "Classification") {
val startTime = System.nanoTime()
- val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
@@ -127,7 +127,7 @@ object GradientBoostedTrees {
println(s"Test accuracy = $testAccuracy")
} else if (params.algo == "Regression") {
val startTime = System.nanoTime()
- val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 78acc17f90..3d91867c89 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+ def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
- val rfModel = rf.train(input)
- rfModel.weakHypotheses(0)
+ val rfModel = rf.run(input)
+ rfModel.trees(0)
}
+ /**
+ * Trains a decision tree model over an RDD. This is deprecated because it hides the static
+ * methods with the same name in Java.
+ */
+ @deprecated("Please use DecisionTree.run instead.", "1.2.0")
+ def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
}
object DecisionTree extends Serializable with Logging {
@@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
maxDepth: Int,
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
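Caller migration for the train-to-run rename is mechanical; a sketch (the strategy and training RDD are assumed to exist):

    val tree = new DecisionTree(strategy)
    val model = tree.run(trainingData)  // new entry point
    // tree.train(trainingData)         // still compiles, but deprecated since 1.2.0:
    //                                  // it hid the static train() overloads from Java.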
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index f729344a68..cb4ddfc814 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -21,18 +21,17 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
/**
* :: Experimental ::
- * A class that implements Stochastic Gradient Boosting
- * for regression and binary classification problems.
+ * A class that implements Stochastic Gradient Boosting for regression and binary classification.
*
* The implementation is based upon:
* J.H. Friedman. "Stochastic Gradient Boosting." 1999.
@@ -45,146 +44,92 @@ import org.apache.spark.storage.StorageLevel
* but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
* Running with those losses will likely behave reasonably, but lacks the same guarantees.
*
- * @param boostingStrategy Parameters for the gradient boosting algorithm
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
*/
@Experimental
-class GradientBoosting (
- private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
-
- boostingStrategy.weakLearnerParams.algo = Regression
- boostingStrategy.weakLearnerParams.impurity = impurity.Variance
-
- // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- boostingStrategy.weakLearnerParams.numClassesForClassification =
- boostingStrategy.numClassesForClassification
-
- boostingStrategy.assertValid()
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+ extends Serializable with Logging {
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a gradient boosted trees model that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
+ def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+ val algo = boostingStrategy.treeStrategy.algo
algo match {
- case Regression => GradientBoosting.boost(input, boostingStrategy)
+ case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoosting.boost(remappedInput, boostingStrategy)
+ GradientBoostedTrees.boost(remappedInput, boostingStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
}
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+ */
+ def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+ run(input.rdd)
+ }
}
-object GradientBoosting extends Logging {
+object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model.
*
- * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
- * is recommended to clearly specify regression.
- * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
- * is recommended to clearly specify regression.
- *
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a gradient boosted trees model that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- new GradientBoosting(boostingStrategy).train(input)
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ new GradientBoostedTrees(boostingStrategy).run(input)
}
/**
- * Method to train a gradient boosting classification model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainClassifier(
- input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
- require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Method to train a gradient boosting regression model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param boostingStrategy Configuration options for the boosting algorithm.
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainRegressor(
- input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- val algo = boostingStrategy.algo
- require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
*/
def train(
- input: JavaRDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- train(input.rdd, boostingStrategy)
- }
-
- /**
- * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
- */
- def trainClassifier(
- input: JavaRDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- trainClassifier(input.rdd, boostingStrategy)
- }
-
- /**
- * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
- */
- def trainRegressor(
input: JavaRDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
- trainRegressor(input.rdd, boostingStrategy)
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ train(input.rdd, boostingStrategy)
}
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param boostingStrategy boosting parameters
- * @return
+ * @return a gradient boosted trees model that can be used for prediction
*/
private def boost(
input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
timer.start("init")
+ boostingStrategy.assertValid()
+
// Initialize gradient boosting parameters
val numIterations = boostingStrategy.numIterations
val baseLearners = new Array[DecisionTreeModel](numIterations)
val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
- val strategy = boostingStrategy.weakLearnerParams
+ // Prepare strategy for individual trees, which use regression with variance impurity.
+ val treeStrategy = boostingStrategy.treeStrategy.copy
+ treeStrategy.algo = Regression
+ treeStrategy.impurity = Variance
+ treeStrategy.assertValid()
// Cache input
if (input.getStorageLevel == StorageLevel.NONE) {
@@ -200,11 +145,10 @@ object GradientBoosting extends Logging {
// Initialize tree
timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(strategy).train(data)
+ val firstTreeModel = new DecisionTree(treeStrategy).run(data)
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = 1.0
- val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
- Sum)
+ val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
logDebug("error of gbt = " + loss.computeError(startingModel, input))
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
@@ -219,7 +163,7 @@ object GradientBoosting extends Logging {
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
- val model = new DecisionTree(strategy).train(data)
+ val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
@@ -228,8 +172,8 @@ object GradientBoosting extends Logging {
// However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
- val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
- baseLearnerWeights.slice(0, m + 1), Regression, Sum)
+ val partialModel = new GradientBoostedTreesModel(
+ Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
logDebug("error of gbt = " + loss.computeError(partialModel, input))
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
@@ -242,8 +186,7 @@ object GradientBoosting extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
- new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
-
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
}
-
}
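The boost method fits each subsequent tree to pseudo-residuals of the running ensemble; a condensed sketch of that loop, with the timing, caching, and validation from the patch omitted:

    baseLearners(0) = new DecisionTree(treeStrategy).run(input)
    baseLearnerWeights(0) = 1.0
    var m = 1
    while (m < numIterations) {
      val partialModel = new GradientBoostedTreesModel(
        Regression, baseLearners.slice(0, m), baseLearnerWeights.slice(0, m))
      // Pseudo-residuals: the negative loss gradient at the current ensemble.
      val residuals = input.map(p => LabeledPoint(-loss.gradient(partialModel, p), p.features))
      baseLearners(m) = new DecisionTree(treeStrategy).run(residuals)
      baseLearnerWeights(m) = learningRate
      m += 1
    }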
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 9683916d9b..ca0b6eea9a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -17,18 +17,18 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.JavaConverters._
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
-import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache }
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
+ TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -79,9 +79,9 @@ private class RandomForest (
/**
* Method to train a random forest model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+ def run(input: RDD[LabeledPoint]): RandomForestModel = {
val timer = new TimeTracker()
@@ -212,8 +212,7 @@ private class RandomForest (
}
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- val treeWeights = Array.fill[Double](numTrees)(1.0)
- new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
+ new RandomForestModel(strategy.algo, trees)
}
}
@@ -234,18 +233,18 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
- rf.train(input)
+ rf.run(input)
}
/**
@@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+ seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -322,18 +321,18 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
- rf.train(input)
+ rf.run(input)
}
/**
@@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return WeightedEnsembleModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+ seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
@@ -387,7 +386,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): WeightedEnsembleModel = {
+ seed: Int): RandomForestModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
3 * totalBins
}
}
-
}
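A sketch of the renamed forest entry points and the new model type (parameter values are illustrative, and trainingData is an assumed RDD[LabeledPoint]):

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel

    val model: RandomForestModel = RandomForest.trainClassifier(
      trainingData,
      numClassesForClassification = 2,
      categoricalFeaturesInfo = Map.empty[Int, Int],
      numTrees = 10,
      featureSubsetStrategy = "sqrt",
      impurity = "gini",
      maxDepth = 4,
      maxBins = 100)
    println(s"trained ${model.numTrees} trees, ${model.totalNumNodes} nodes in total")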
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index abbda040bd..e703adbdbf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
/**
* :: Experimental ::
- * Stores all the configuration options for the boosting algorithms
- * @param algo Learning goal. Supported:
- * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]].
+ *
+ * @param treeStrategy Parameters for the tree algorithm. We support regression and binary
+ * classification for boosting. Impurity setting will be ignored.
+ * @param loss Loss function used for minimization during gradient boosting.
* @param numIterations Number of iterations of boosting. In other words, the number of
* weak hypotheses used in the final model.
- * @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be in the interval (0, 1]
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * This setting overrides any setting in [[weakLearnerParams]].
- * Default value is 2 (binary classification).
- * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
- * supported.
*/
@Experimental
case class BoostingStrategy(
// Required boosting parameters
- @BeanProperty var algo: Algo,
- @BeanProperty var numIterations: Int,
+ @BeanProperty var treeStrategy: Strategy,
@BeanProperty var loss: Loss,
// Optional boosting parameters
- @BeanProperty var learningRate: Double = 0.1,
- @BeanProperty var numClassesForClassification: Int = 2,
- @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
-
- // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- weakLearnerParams.numClassesForClassification = numClassesForClassification
-
- /**
- * Sets Algorithm using a String.
- */
- def setAlgo(algo: String): Unit = algo match {
- case "Classification" => setAlgo(Classification)
- case "Regression" => setAlgo(Regression)
- }
+ @BeanProperty var numIterations: Int = 100,
+ @BeanProperty var learningRate: Double = 0.1) extends Serializable {
/**
* Check validity of parameters.
* Throws exception if invalid.
*/
private[tree] def assertValid(): Unit = {
- algo match {
+ treeStrategy.algo match {
case Classification =>
- require(numClassesForClassification == 2)
+ require(treeStrategy.numClassesForClassification == 2,
+ "Only binary classification is supported for boosting.")
case Regression =>
// nothing
case _ =>
throw new IllegalArgumentException(
- s"BoostingStrategy given invalid algo parameter: $algo." +
+ s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." +
s" Valid settings are: Classification, Regression.")
}
require(learningRate > 0 && learningRate <= 1,
@@ -94,14 +76,14 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy("Regression")
+ val treeStrategy = Strategy.defaultStrategy(algo)
treeStrategy.maxDepth = 3
algo match {
case "Classification" =>
- new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+ treeStrategy.numClassesForClassification = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
case "Regression" =>
- new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
- weakLearnerParams = treeStrategy)
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
}
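The reshaped strategy nests all per-tree options under treeStrategy; a construction sketch (values illustrative):

    import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
    import org.apache.spark.mllib.tree.loss.SquaredError

    val treeStrategy = Strategy.defaultStrategy("Regression")
    treeStrategy.maxDepth = 3
    // numIterations and learningRate keep their defaults of 100 and 0.1 unless set.
    val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, numIterations = 50)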
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
index 82889dc00c..b5bf732d1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -17,14 +17,10 @@
package org.apache.spark.mllib.tree.configuration
-import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: Experimental ::
* Enum to select ensemble combining strategy for base learners
*/
-@DeveloperApi
-object EnsembleCombiningStrategy extends Enumeration {
+private[tree] object EnsembleCombiningStrategy extends Enumeration {
type EnsembleCombiningStrategy = Value
- val Sum, Average = Value
+ val Average, Sum, Vote = Value
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b5b1f82177..d75f38433c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -157,6 +157,13 @@ class Strategy (
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+
+ /** Returns a shallow copy of this instance. */
+ def copy: Strategy = {
+ new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+ quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
+ maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+ }
}
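The shallow copy exists so the boosting code can override per-tree fields without mutating the caller's strategy; a sketch:

    import org.apache.spark.mllib.tree.configuration.Algo.{Classification, Regression}
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Variance

    val callerStrategy = Strategy.defaultStrategy("Classification")
    val internal = callerStrategy.copy
    internal.algo = Regression       // boosting always fits regression trees internally
    internal.impurity = Variance
    assert(callerStrategy.algo == Classification)  // the caller's object is untouched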
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d111ffe30e..e828866809 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -42,7 +42,7 @@ object AbsoluteError extends Loss {
* @return Loss gradient
*/
override def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double = {
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
}
@@ -55,7 +55,7 @@ object AbsoluteError extends Loss {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Mean absolute error of the model on the data
*/
- override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
val sumOfAbsolutes = data.map { y =>
val err = model.predict(y.features) - y.label
math.abs(err)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 6f3d4340f0..8b8adb44ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -42,7 +42,7 @@ object LogLoss extends Loss {
* @return Loss gradient
*/
override def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double = {
val prediction = model.predict(point.features)
1.0 / (1.0 + math.exp(-prediction)) - point.label
@@ -56,7 +56,7 @@ object LogLoss extends Loss {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Fraction of wrong predictions made by the model
*/
- override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
wrongPredictions.toDouble / data.count // convert to Double to avoid integer division
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 5580866c87..4bca9039eb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -36,7 +36,7 @@ trait Loss extends Serializable {
* @return Loss gradient.
*/
def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double
/**
@@ -47,6 +47,6 @@ trait Loss extends Serializable {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Measure of the model error on the data
*/
- def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
}
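Since the renamed TreeEnsembleModel is private[tree] (see treeEnsembleModels.scala below), a custom Loss can only be defined inside that package; a minimal squared-error sketch under that assumption, with a hypothetical object name:

    package org.apache.spark.mllib.tree.loss

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.model.TreeEnsembleModel
    import org.apache.spark.rdd.RDD

    object HalfSquaredError extends Loss {
      // Gradient of (prediction - label)^2 / 2 with respect to the prediction.
      override def gradient(model: TreeEnsembleModel, point: LabeledPoint): Double =
        model.predict(point.features) - point.label

      override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double =
        data.map { lp =>
          val err = model.predict(lp.features) - lp.label
          err * err
        }.mean()
    }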
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 4349fefef2..cfe395b1d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
/**
@@ -43,7 +43,7 @@ object SquaredError extends Loss {
* @return Loss gradient
*/
override def gradient(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
point: LabeledPoint): Double = {
model.predict(point.features) - point.label
}
@@ -56,7 +56,7 @@ object SquaredError extends Loss {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Mean squared error of the model on the data
*/
- override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = model.predict(y.features) - y.label
err * err
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ac4d02ee39..a576096306 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,11 @@
package org.apache.spark.mllib.tree.model
-import org.apache.spark.api.java.JavaRDD
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.Vector
/**
* :: Experimental ::
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
deleted file mode 100644
index 7b052d9163..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
+++ /dev/null
@@ -1,158 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
-import org.apache.spark.rdd.RDD
-
-import scala.collection.mutable
-
-@Experimental
-class WeightedEnsembleModel(
- val weakHypotheses: Array[DecisionTreeModel],
- val weakHypothesisWeights: Array[Double],
- val algo: Algo,
- val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
-
- require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" +
- s". Number of weakHypotheses = $weakHypotheses")
-
- /**
- * Predict values for a single data point using the model trained.
- *
- * @param features array representing a single data point
- * @return predicted category from the trained model
- */
- private def predictRaw(features: Vector): Double = {
- val treePredictions = weakHypotheses.map(learner => learner.predict(features))
- if (numWeakHypotheses == 1){
- treePredictions(0)
- } else {
- var prediction = treePredictions(0)
- var index = 1
- while (index < numWeakHypotheses) {
- prediction += weakHypothesisWeights(index) * treePredictions(index)
- index += 1
- }
- prediction
- }
- }
-
- /**
- * Predict values for a single data point using the model trained.
- *
- * @param features array representing a single data point
- * @return predicted category from the trained model
- */
- private def predictBySumming(features: Vector): Double = {
- algo match {
- case Regression => predictRaw(features)
- case Classification => {
- // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
- if (predictRaw(features) > 0 ) 1.0 else 0.0
- }
- case _ => throw new IllegalArgumentException(
- s"WeightedEnsembleModel given unknown algo parameter: $algo.")
- }
- }
-
- /**
- * Predict values for a single data point.
- *
- * @param features array representing a single data point
- * @return Double prediction from the trained model
- */
- private def predictByAveraging(features: Vector): Double = {
- algo match {
- case Classification =>
- val predictionToCount = new mutable.HashMap[Int, Int]()
- weakHypotheses.foreach { learner =>
- val prediction = learner.predict(features).toInt
- predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
- }
- predictionToCount.maxBy(_._2)._1
- case Regression =>
- weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
- }
- }
-
-
- /**
- * Predict values for a single data point using the model trained.
- *
- * @param features array representing a single data point
- * @return predicted category from the trained model
- */
- def predict(features: Vector): Double = {
- combiningStrategy match {
- case Sum => predictBySumming(features)
- case Average => predictByAveraging(features)
- case _ => throw new IllegalArgumentException(
- s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
- }
- }
-
- /**
- * Predict values for the given data set.
- *
- * @param features RDD representing data points to be predicted
- * @return RDD[Double] where each entry contains the corresponding prediction
- */
- def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
-
- /**
- * Print a summary of the model.
- */
- override def toString: String = {
- algo match {
- case Classification =>
- s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n"
- case Regression =>
- s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n"
- case _ => throw new IllegalArgumentException(
- s"WeightedEnsembleModel given unknown algo parameter: $algo.")
- }
- }
-
- /**
- * Print the full model to a string.
- */
- def toDebugString: String = {
- val header = toString + "\n"
- header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
- s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
- }.fold("")(_ + _)
- }
-
- /**
- * Get number of trees in forest.
- */
- def numWeakHypotheses: Int = weakHypotheses.size
-
- // TODO: Remove these helpers methods once class is generalized to support any base learning
- // algorithms.
-
- /**
- * Get total number of nodes, summed over all trees in the forest.
- */
- def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
new file mode 100644
index 0000000000..22997110de
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Represents a random forest model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees the component decision trees
+ */
+@Experimental
+class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ combiningStrategy = if (algo == Classification) Vote else Average) {
+
+ require(trees.forall(_.algo == algo))
+}
+
+/**
+ * :: Experimental ::
+ * Represents a gradient boosted trees model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees the component decision trees
+ * @param treeWeights weights of the component trees
+ */
+@Experimental
+class GradientBoostedTreesModel(
+ override val algo: Algo,
+ override val trees: Array[DecisionTreeModel],
+ override val treeWeights: Array[Double])
+ extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+
+ require(trees.size == treeWeights.size)
+}
+
+/**
+ * Represents a tree ensemble model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees the component decision trees
+ * @param treeWeights weights of the component trees
+ * @param combiningStrategy strategy for combining the predictions of the component trees
+ */
+private[tree] sealed class TreeEnsembleModel(
+ protected val algo: Algo,
+ protected val trees: Array[DecisionTreeModel],
+ protected val treeWeights: Array[Double],
+ protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+ require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")
+
+ private val sumWeights = math.max(treeWeights.sum, 1e-15)
+
+ /**
+ * Predicts for a single data point using the weighted sum of ensemble predictions.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ private def predictBySumming(features: Vector): Double = {
+ val treePredictions = trees.map(_.predict(features))
+ blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
+ }
+
+ /**
+ * Classifies a single data point based on (weighted) majority votes.
+ */
+ private def predictByVoting(features: Vector): Double = {
+ val votes = mutable.Map.empty[Int, Double]
+ trees.view.zip(treeWeights).foreach { case (tree, weight) =>
+ val prediction = tree.predict(features).toInt
+ votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
+ }
+ votes.maxBy(_._2)._1
+ }
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ def predict(features: Vector): Double = {
+ (algo, combiningStrategy) match {
+ case (Regression, Sum) =>
+ predictBySumming(features)
+ case (Regression, Average) =>
+ predictBySumming(features) / sumWeights
+ case (Classification, Sum) => // binary classification
+ val prediction = predictBySumming(features)
+ // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+ if (prediction > 0.0) 1.0 else 0.0
+ case (Classification, Vote) =>
+ predictByVoting(features)
+ case _ =>
+ throw new IllegalArgumentException(
+ "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
+ s"($algo, $combiningStrategy).")
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+ /**
+ * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]].
+ */
+ def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+ predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
+ }
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = {
+ algo match {
+ case Classification =>
+ s"TreeEnsembleModel classifier with $numTrees trees\n"
+ case Regression =>
+ s"TreeEnsembleModel regressor with $numTrees trees\n"
+ case _ => throw new IllegalArgumentException(
+ s"TreeEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /**
+ * Get number of trees in forest.
+ */
+ def numTrees: Int = trees.size
+
+ /**
+ * Get total number of nodes, summed over all trees in the forest.
+ */
+ def totalNumNodes: Int = trees.map(_.numNodes).sum
+}
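The summing path is just a BLAS dot product of per-tree predictions and weights; a numeric sketch of the GBT binary-classification case (values hypothetical):

    import com.github.fommil.netlib.BLAS.{getInstance => blas}

    val treePredictions = Array(1.0, -1.0, 1.0)  // raw per-tree outputs in {-1, +1}
    val treeWeights = Array(1.0, 0.1, 0.1)       // tree 0 has weight 1.0, then learningRate
    val raw = blas.ddot(3, treePredictions, 1, treeWeights, 1)  // 1.0 - 0.1 + 0.1 = 1.0
    val label = if (raw > 0.0) 1.0 else 0.0      // threshold at zero, as in predict above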
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 2c281a1ee7..9925aae441 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -74,7 +74,7 @@ public class JavaDecisionTreeSuite implements Serializable {
maxBins, categoricalFeaturesInfo);
DecisionTree learner = new DecisionTree(strategy);
- DecisionTreeModel model = learner.train(rdd.rdd());
+ DecisionTreeModel model = learner.run(rdd.rdd());
int numCorrect = validatePrediction(arr, model);
Assert.assertTrue(numCorrect == rdd.count());
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index effb7b8259..8972c229b7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.util.StatCounter
import scala.collection.mutable
@@ -48,7 +48,7 @@ object EnsembleTestHelper {
}
def validateClassifier(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
input: Seq[LabeledPoint],
requiredAccuracy: Double) {
val predictions = input.map(x => model.predict(x.features))
@@ -60,17 +60,27 @@ object EnsembleTestHelper {
s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
}
+ /**
+ * Validates a tree ensemble model for regression.
+ */
def validateRegressor(
- model: WeightedEnsembleModel,
+ model: TreeEnsembleModel,
input: Seq[LabeledPoint],
- requiredMSE: Double) {
+ required: Double,
+ metricName: String = "mse") {
val predictions = input.map(x => model.predict(x.features))
- val squaredError = predictions.zip(input).map { case (prediction, expected) =>
- val err = prediction - expected.label
- err * err
- }.sum
- val mse = squaredError / input.length
- assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
+ prediction - label
+ }
+ val metric = metricName match {
+ case "mse" =>
+ errors.map(err => err * err).sum / errors.size
+ case "mae" =>
+ errors.map(math.abs).sum / errors.size
+ }
+
+ assert(metric <= required,
+ s"validateRegressor calculated $metricName $metric but required $required.")
}
def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {
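
The new metric dispatch can be exercised on plain sequences as well; below is a self-contained sketch of the same MSE/MAE logic, where `regressionMetric` is an illustrative name and not part of the patch.

    // Sketch of the metric dispatch added above, on plain sequences.
    def regressionMetric(predictions: Seq[Double], labels: Seq[Double],
        metricName: String = "mse"): Double = {
      val errors = predictions.zip(labels).map { case (p, l) => p - l }
      metricName match {
        case "mse" => errors.map(err => err * err).sum / errors.size
        case "mae" => errors.map(math.abs).sum / errors.size
      }
    }

    regressionMetric(Seq(1.0, 2.0), Seq(0.5, 2.5))        // MSE = 0.25
    regressionMetric(Seq(1.0, 2.0), Seq(0.5, 2.5), "mae") // MAE = 0.5
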
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 84de40103d..f3f8eff2db 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -23,104 +23,95 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
- * Test suite for [[GradientBoosting]].
+ * Test suite for [[GradientBoostedTrees]].
*/
-class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
+class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
- GradientBoostingSuite.testCombinations.foreach {
+ GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
- val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val rdd = sc.parallelize(arr, 2)
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
- subsamplingRate = subsamplingRate)
-
- val dt = DecisionTree.train(remappedInput, treeStrategy)
-
- val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
- learningRate, 1, treeStrategy)
+ categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
- val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numIterations)
- val gbtTree = gbt.weakHypotheses(0)
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+ assert(gbt.trees.size === numIterations)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
// Make sure trees are the same.
- assert(gbtTree.toString == dt.toString)
+ assert(gbt.trees.head.toString == dt.toString)
}
}
test("Regression with continuous features: Absolute Error") {
- GradientBoostingSuite.testCombinations.foreach {
+ GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
- val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val rdd = sc.parallelize(arr, 2)
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
- subsamplingRate = subsamplingRate)
-
- val dt = DecisionTree.train(remappedInput, treeStrategy)
+ categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate)
- val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
- learningRate, numClassesForClassification = 2, treeStrategy)
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
- val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numIterations)
- val gbtTree = gbt.weakHypotheses(0)
+ assert(gbt.trees.size === numIterations)
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.85, "mae")
- EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
// Make sure trees are the same.
- assert(gbtTree.toString == dt.toString)
+ assert(gbt.trees.head.toString == dt.toString)
}
}
test("Binary classification with continuous features: Log Loss") {
- GradientBoostingSuite.testCombinations.foreach {
+ GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
- val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val rdd = sc.parallelize(arr, 2)
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty,
subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate)
- val dt = DecisionTree.train(remappedInput, treeStrategy)
-
- val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
- learningRate, numClassesForClassification = 2, treeStrategy)
-
- val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numIterations)
- val gbtTree = gbt.weakHypotheses(0)
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+ assert(gbt.trees.size === numIterations)
EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val ensembleStrategy = treeStrategy.copy
+ ensembleStrategy.algo = Regression
+ ensembleStrategy.impurity = Variance
+ val dt = DecisionTree.train(remappedInput, ensembleStrategy)
+
// Make sure trees are the same.
- assert(gbtTree.toString == dt.toString)
+ assert(gbt.trees.head.toString == dt.toString)
}
}
}
-object GradientBoostingSuite {
+object GradientBoostedTreesSuite {
// Combinations of (numIterations, learningRate, subsamplingRate) to test
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
-
}
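
Putting the renamed pieces together: BoostingStrategy is now built from the tree strategy followed by the loss, iteration count, and learning rate, and training goes through GradientBoostedTrees.train. A sketch under exactly those signatures; `fitGbt` is an illustrative helper name and `input` is assumed.

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.Algo.Regression
    import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
    import org.apache.spark.mllib.tree.impurity.Variance
    import org.apache.spark.mllib.tree.loss.SquaredError
    import org.apache.spark.rdd.RDD

    // `input` is assumed to be prepared elsewhere; not part of this patch.
    def fitGbt(input: RDD[LabeledPoint]) = {
      val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
        categoricalFeaturesInfo = Map.empty)
      // Arguments: tree strategy, loss, numIterations, learningRate.
      val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 10, 0.1)
      val gbt = GradientBoostedTrees.train(input, boostingStrategy)
      assert(gbt.trees.size == 10)  // one tree per boosting iteration
      gbt
    }
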
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 2734e089d6..90a8c2dfda 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -41,8 +41,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.weakHypotheses.size === 1)
- val rfTree = rf.weakHypotheses(0)
+ assert(rf.trees.size === 1)
+ val rfTree = rf.trees(0)
val dt = DecisionTree.train(rdd, strategy)
@@ -65,7 +65,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
binaryClassificationTestWithContinuousFeatures(strategy)
}
@@ -76,8 +77,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.weakHypotheses.size === 1)
- val rfTree = rf.weakHypotheses(0)
+ assert(rf.trees.size === 1)
+ val rfTree = rf.trees(0)
val dt = DecisionTree.train(rdd, strategy)
@@ -175,7 +176,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
test("Binary classification with continuous features and node Id cache: subsampling features") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}