From 15cacc81240eed8834b4730c5c6dc3238f003465 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Nov 2014 00:48:59 -0800 Subject: [SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent. 1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly. 1. GradientBoosting -> GradientBoostedTrees 1. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy 1. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.) 1. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train` 1. Rename class `train` method to `run` because it hides the static methods with the same name in Java. Deprecated `DecisionTree.train` class method. 1. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting. 1. Fix a bug in GradientBoostedTreesSuite with AbsoluteError 1. doc updates manishamde jkbradley Author: Xiangrui Meng Closes #3374 from mengxr/SPARK-4486 and squashes the following commits: 7097251 [Xiangrui Meng] address joseph's comments 98dea09 [Xiangrui Meng] address manish's comments 4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy ea4c467 [Xiangrui Meng] fix unit tests 751da4e [Xiangrui Meng] rename class method train -> run 19030a5 [Xiangrui Meng] update boosting public APIs --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 20 +- .../spark/mllib/tree/GradientBoostedTrees.scala | 192 ++++++++++++++++ .../apache/spark/mllib/tree/GradientBoosting.scala | 249 --------------------- .../org/apache/spark/mllib/tree/RandomForest.scala | 40 ++-- .../tree/configuration/BoostingStrategy.scala | 50 ++--- .../configuration/EnsembleCombiningStrategy.scala | 8 +- .../spark/mllib/tree/configuration/Strategy.scala | 7 + .../spark/mllib/tree/loss/AbsoluteError.scala | 6 +- .../org/apache/spark/mllib/tree/loss/LogLoss.scala | 6 +- .../org/apache/spark/mllib/tree/loss/Loss.scala | 6 +- .../spark/mllib/tree/loss/SquaredError.scala | 6 +- .../spark/mllib/tree/model/DecisionTreeModel.scala | 4 +- .../mllib/tree/model/WeightedEnsembleModel.scala | 158 ------------- .../mllib/tree/model/treeEnsembleModels.scala | 178 +++++++++++++++ .../spark/mllib/tree/JavaDecisionTreeSuite.java | 2 +- .../spark/mllib/tree/EnsembleTestHelper.scala | 30 ++- .../mllib/tree/GradientBoostedTreesSuite.scala | 117 ++++++++++ .../spark/mllib/tree/GradientBoostingSuite.scala | 126 ----------- .../spark/mllib/tree/RandomForestSuite.scala | 14 +- 19 files changed, 587 insertions(+), 632 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala (limited to 
'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 78acc17f90..3d91867c89 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) - val rfModel = rf.train(input) - rfModel.weakHypotheses(0) + val rfModel = rf.run(input) + rfModel.trees(0) } + /** + * Trains a decision tree model over an RDD. This is deprecated because it hides the static + * methods with the same name in Java. + */ + @deprecated("Please use DecisionTree.run instead.", "1.2.0") + def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input) } object DecisionTree extends Serializable with Logging { @@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging { * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, numClassesForClassification: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala new file mode 100644 index 0000000000..cb4ddfc814 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: Experimental ::
+ * A class that implements Stochastic Gradient Boosting for regression and binary classification.
+ *
+ * The implementation is based upon:
+ *   J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes:
+ *  - This currently can be run with several loss functions. However, only SquaredError is
+ *    fully supported. Specifically, the loss function should be used to compute the gradient
+ *    (to re-label training instances on each iteration) and to weight weak hypotheses.
+ *    Currently, gradients are computed correctly for the available loss functions,
+ *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+@Experimental
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+  extends Serializable with Logging {
+
+  /**
+   * Method to train a gradient boosting model.
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+    val algo = boostingStrategy.treeStrategy.algo
+    algo match {
+      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+      case Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
+        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+    }
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+   */
+  def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+    run(input.rdd)
+  }
+}
+
+
+object GradientBoostedTrees extends Logging {
+
+  /**
+   * Method to train a gradient boosting model.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param boostingStrategy Configuration options for the boosting algorithm.
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+    new GradientBoostedTrees(boostingStrategy).run(input)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
+   */
+  def train(
+      input: JavaRDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+    train(input.rdd, boostingStrategy)
+  }
+
+  /**
+   * Internal method for performing regression using trees as base learners.
+   * @param input training dataset
+   * @param boostingStrategy boosting parameters
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  private def boost(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+
+    val timer = new TimeTracker()
+    timer.start("total")
+    timer.start("init")
+
+    boostingStrategy.assertValid()
+
+    // Initialize gradient boosting parameters
+    val numIterations = boostingStrategy.numIterations
+    val baseLearners = new Array[DecisionTreeModel](numIterations)
+    val baseLearnerWeights = new Array[Double](numIterations)
+    val loss = boostingStrategy.loss
+    val learningRate = boostingStrategy.learningRate
+    // Prepare strategy for individual trees, which use regression with variance impurity.
+    val treeStrategy = boostingStrategy.treeStrategy.copy
+    treeStrategy.algo = Regression
+    treeStrategy.impurity = Variance
+    treeStrategy.assertValid()
+
+    // Cache input
+    if (input.getStorageLevel == StorageLevel.NONE) {
+      input.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    timer.stop("init")
+
+    logDebug("##########")
+    logDebug("Building tree 0")
+    logDebug("##########")
+    var data = input
+
+    // Initialize tree
+    timer.start("building tree 0")
+    val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    baseLearners(0) = firstTreeModel
+    baseLearnerWeights(0) = 1.0
+    val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
+    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    // Note: A model of type regression is used since we require raw prediction
+    timer.stop("building tree 0")
+
+    // Compute the pseudo-residuals (negative gradients) for the second iteration.
+    data = input.map(point => LabeledPoint(-loss.gradient(startingModel, point),
+      point.features))
+
+    var m = 1
+    while (m < numIterations) {
+      timer.start(s"building tree $m")
+      logDebug("###################################################")
+      logDebug("Gradient boosting tree iteration " + m)
+      logDebug("###################################################")
+      val model = new DecisionTree(treeStrategy).run(data)
+      timer.stop(s"building tree $m")
+      // Create partial model
+      baseLearners(m) = model
+      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+      //       Technically, the weight should be optimized for the particular loss.
+      //       However, the behavior should be reasonable, though not optimal.
+      baseLearnerWeights(m) = learningRate
+      // Note: A model of type regression is used since we require raw prediction
+      val partialModel = new GradientBoostedTreesModel(
+        Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
+      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+      // Update data with pseudo-residuals
+      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+        point.features))
+      m += 1
+    }
+
+    timer.stop("total")
+
+    logInfo("Internal timing for GradientBoostedTrees:")
+    logInfo(s"$timer")
+
+    new GradientBoostedTreesModel(
+      boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
deleted file mode 100644
index f729344a68..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ /dev/null
@@ -1,249 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree
-
-import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: Experimental ::
- * A class that implements Stochastic Gradient Boosting
- * for regression and binary classification problems.
- *
- * The implementation is based upon:
- *   J.H. Friedman. "Stochastic Gradient Boosting." 1999.
- *
- * Notes:
- *  - This currently can be run with several loss functions. However, only SquaredError is
- *    fully supported. Specifically, the loss function should be used to compute the gradient
- *    (to re-label training instances on each iteration) and to weight weak hypotheses.
- *    Currently, gradients are computed correctly for the available loss functions,
- *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
- *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
- * - * @param boostingStrategy Parameters for the gradient boosting algorithm - */ -@Experimental -class GradientBoosting ( - private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { - - boostingStrategy.weakLearnerParams.algo = Regression - boostingStrategy.weakLearnerParams.impurity = impurity.Variance - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - boostingStrategy.weakLearnerParams.numClassesForClassification = - boostingStrategy.numClassesForClassification - - boostingStrategy.assertValid() - - /** - * Method to train a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return WeightedEnsembleModel that can be used for prediction - */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - algo match { - case Regression => GradientBoosting.boost(input, boostingStrategy) - case Classification => - // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoosting.boost(remappedInput, boostingStrategy) - case _ => - throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") - } - } - -} - - -object GradientBoosting extends Logging { - - /** - * Method to train a gradient boosting model. - * - * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - * is recommended to clearly specify regression. - * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - * is recommended to clearly specify regression. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainClassifier( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. 
- * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] - */ - def train( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - train(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - */ - def trainClassifier( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainClassifier(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - */ - def trainRegressor( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainRegressor(input.rdd, boostingStrategy) - } - - /** - * Internal method for performing regression using trees as base learners. - * @param input training dataset - * @param boostingStrategy boosting parameters - * @return - */ - private def boost( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - - val timer = new TimeTracker() - timer.start("total") - timer.start("init") - - // Initialize gradient boosting parameters - val numIterations = boostingStrategy.numIterations - val baseLearners = new Array[DecisionTreeModel](numIterations) - val baseLearnerWeights = new Array[Double](numIterations) - val loss = boostingStrategy.loss - val learningRate = boostingStrategy.learningRate - val strategy = boostingStrategy.weakLearnerParams - - // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { - input.persist(StorageLevel.MEMORY_AND_DISK) - } - - timer.stop("init") - - logDebug("##########") - logDebug("Building tree 0") - logDebug("##########") - var data = input - - // Initialize tree - timer.start("building tree 0") - val firstTreeModel = new DecisionTree(strategy).train(data) - baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = 1.0 - val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, - Sum) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) - // Note: A model of type regression is used since we require raw prediction - timer.stop("building tree 0") - - // psuedo-residual for second iteration - data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), - point.features)) - - var m = 1 - while (m < numIterations) { - timer.start(s"building tree $m") - logDebug("###################################################") - logDebug("Gradient boosting tree iteration " + m) - logDebug("###################################################") - val model = new DecisionTree(strategy).train(data) - timer.stop(s"building tree $m") - // Create partial model - baseLearners(m) = model - // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. - // Technically, the weight should be optimized for the particular loss. - // However, the behavior should be reasonable, though not optimal. 
- baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1), Regression, Sum) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) - // Update data with pseudo-residuals - data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), - point.features)) - m += 1 - } - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) - - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 9683916d9b..ca0b6eea9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache } +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, + TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -79,9 +79,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { + def run(input: RDD[LabeledPoint]): RandomForestModel = { val timer = new TimeTracker() @@ -212,8 +212,7 @@ private class RandomForest ( } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - val treeWeights = Array.fill[Double](numTrees)(1.0) - new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average) + new RandomForestModel(strategy.algo, trees) } } @@ -234,18 +233,18 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. 
- * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) @@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -322,18 +321,18 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. 
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a random forest model that can be used for prediction
    */
   def trainRegressor(
       input: RDD[LabeledPoint],
       categoricalFeaturesInfo: Map[Int, Int],
       numTrees: Int,
       featureSubsetStrategy: String,
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
     val impurityType = Impurities.fromString(impurity)
     val strategy = new Strategy(Regression, impurityType, maxDepth,
       0, maxBins, Sort, categoricalFeaturesInfo)
@@ -387,7 +386,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int): WeightedEnsembleModel = {
+      seed: Int): RandomForestModel = {
     trainRegressor(input.rdd,
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
       3 * totalBins
     }
   }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index abbda040bd..e703adbdbf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
 
 /**
  * :: Experimental ::
- * Stores all the configuration options for the boosting algorithms
- * @param algo Learning goal. Supported:
- *             [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]].
+ *
+ * @param treeStrategy Parameters for the tree algorithm. We support regression and binary
+ *                     classification for boosting. Impurity setting will be ignored.
+ * @param loss Loss function used for minimization during gradient boosting.
  * @param numIterations Number of iterations of boosting. In other words, the number of
  *                      weak hypotheses used in the final model.
- * @param loss Loss function used for minimization during gradient boosting.
  * @param learningRate Learning rate for shrinking the contribution of each estimator. The
  *                     learning rate should be in the interval (0, 1].
- * @param numClassesForClassification Number of classes for classification.
- *                                    (Ignored for regression.)
- *                                    This setting overrides any setting in [[weakLearnerParams]].
- *                                    Default value is 2 (binary classification).
- * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
- *                          supported.
  */
@Experimental
case class BoostingStrategy(
    // Required boosting parameters
-    @BeanProperty var algo: Algo,
-    @BeanProperty var numIterations: Int,
+    @BeanProperty var treeStrategy: Strategy,
    @BeanProperty var loss: Loss,
    // Optional boosting parameters
-    @BeanProperty var learningRate: Double = 0.1,
-    @BeanProperty var numClassesForClassification: Int = 2,
-    @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
-
-  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
-  weakLearnerParams.numClassesForClassification = numClassesForClassification
-
-  /**
-   * Sets Algorithm using a String.
-   */
-  def setAlgo(algo: String): Unit = algo match {
-    case "Classification" => setAlgo(Classification)
-    case "Regression" => setAlgo(Regression)
-  }
+    @BeanProperty var numIterations: Int = 100,
+    @BeanProperty var learningRate: Double = 0.1) extends Serializable {
 
   /**
    * Check validity of parameters.
    * Throws exception if invalid.
    */
   private[tree] def assertValid(): Unit = {
-    algo match {
+    treeStrategy.algo match {
       case Classification =>
-        require(numClassesForClassification == 2)
+        require(treeStrategy.numClassesForClassification == 2,
+          "Only binary classification is supported for boosting.")
       case Regression =>
         // nothing
       case _ =>
         throw new IllegalArgumentException(
-          s"BoostingStrategy given invalid algo parameter: $algo." +
+          s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." +
           s" Valid settings are: Classification, Regression.")
     }
     require(learningRate > 0 && learningRate <= 1,
@@ -94,14 +76,14 @@ object BoostingStrategy {
    * @return Configuration for boosting algorithm
    */
   def defaultParams(algo: String): BoostingStrategy = {
-    val treeStrategy = Strategy.defaultStrategy("Regression")
+    val treeStrategy = Strategy.defaultStrategy(algo)
     treeStrategy.maxDepth = 3
     algo match {
       case "Classification" =>
-        new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+        treeStrategy.numClassesForClassification = 2
+        new BoostingStrategy(treeStrategy, LogLoss)
       case "Regression" =>
-        new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
-          weakLearnerParams = treeStrategy)
+        new BoostingStrategy(treeStrategy, SquaredError)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by boosting.")
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
index 82889dc00c..b5bf732d1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -17,14 +17,10 @@
 
 package org.apache.spark.mllib.tree.configuration
 
-import org.apache.spark.annotation.DeveloperApi
-
 /**
- * :: Experimental ::
  * Enum to select ensemble combining strategy for base learners
  */
-@DeveloperApi
-object EnsembleCombiningStrategy extends Enumeration {
+private[tree] object EnsembleCombiningStrategy extends Enumeration {
   type EnsembleCombiningStrategy = Value
-  val Sum, Average = Value
+  val Average, Sum, Vote = Value
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b5b1f82177..d75f38433c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -157,6 +157,13 @@ class Strategy (
     require(maxMemoryInMB <= 10240,
       s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
   }
+
+  /** Returns a shallow copy of this instance.
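+   *
+   * For example (an illustrative sketch, not code from this patch; `strategy` stands for any
+   * configured Strategy instance), mutating the copy leaves the original untouched:
+   * {{{
+   *   val treeStrategy = strategy.copy
+   *   treeStrategy.algo = Regression    // e.g. what GradientBoostedTrees.boost does internally
+   *   treeStrategy.impurity = Variance  // `strategy` itself is unchanged
+   * }}}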
*/ + def copy: Strategy = { + new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + } } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d111ffe30e..e828866809 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -42,7 +42,7 @@ object AbsoluteError extends Loss { * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 } @@ -55,7 +55,7 @@ object AbsoluteError extends Loss { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { val sumOfAbsolutes = data.map { y => val err = model.predict(y.features) - y.label math.abs(err) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 6f3d4340f0..8b8adb44ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -42,7 +42,7 @@ object LogLoss extends Loss { * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { val prediction = model.predict(point.features) 1.0 / (1.0 + math.exp(-prediction)) - point.label @@ -56,7 +56,7 @@ object LogLoss extends Loss { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
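*              Note: as implemented below, this returns the fraction of misclassified points
*              (0-1 error) rather than the logistic loss itself.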
 * @return
 */
-  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+  override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
     val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
     wrongPredictions.toDouble / data.count
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 5580866c87..4bca9039eb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
 /**
@@ -36,7 +36,7 @@ trait Loss extends Serializable {
    * @return Loss gradient.
    */
   def gradient(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       point: LabeledPoint): Double
 
   /**
@@ -47,6 +47,6 @@ trait Loss extends Serializable {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return
    */
-  def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 4349fefef2..cfe395b1d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
 /**
@@ -43,7 +43,7 @@ object SquaredError extends Loss {
    * @return Loss gradient
    */
   override def gradient(
-      model: WeightedEnsembleModel,
+      model: TreeEnsembleModel,
       point: LabeledPoint): Double = {
     model.predict(point.features) - point.label
   }
@@ -56,7 +56,7 @@ object SquaredError extends Loss {
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map { y => val err = model.predict(y.features) - y.label err * err diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ac4d02ee39..a576096306 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala deleted file mode 100644 index 7b052d9163..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ -import org.apache.spark.rdd.RDD - -import scala.collection.mutable - -@Experimental -class WeightedEnsembleModel( - val weakHypotheses: Array[DecisionTreeModel], - val weakHypothesisWeights: Array[Double], - val algo: Algo, - val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { - - require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" + - s". Number of weakHypotheses = $weakHypotheses") - - /** - * Predict values for a single data point using the model trained. 
- * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictRaw(features: Vector): Double = { - val treePredictions = weakHypotheses.map(learner => learner.predict(features)) - if (numWeakHypotheses == 1){ - treePredictions(0) - } else { - var prediction = treePredictions(0) - var index = 1 - while (index < numWeakHypotheses) { - prediction += weakHypothesisWeights(index) * treePredictions(index) - index += 1 - } - prediction - } - } - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictBySumming(features: Vector): Double = { - algo match { - case Regression => predictRaw(features) - case Classification => { - // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. - if (predictRaw(features) > 0 ) 1.0 else 0.0 - } - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - private def predictByAveraging(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - weakHypotheses.foreach { learner => - val prediction = learner.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size - } - } - - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - def predict(features: Vector): Double = { - combiningStrategy match { - case Sum => predictBySumming(features) - case Average => predictByAveraging(features) - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.") - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) - - /** - * Print a summary of the model. - */ - override def toString: String = { - algo match { - case Classification => - s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n" - case Regression => - s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n" - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - - /** - * Get number of trees in forest. - */ - def numWeakHypotheses: Int = weakHypotheses.size - - // TODO: Remove these helpers methods once class is generalized to support any base learning - // algorithms. - - /** - * Get total number of nodes, summed over all trees in the forest. 
-   */
-  def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
new file mode 100644
index 0000000000..22997110de
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Represents a random forest model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees decision trees in the ensemble
+ */
+@Experimental
+class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
+  extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+    combiningStrategy = if (algo == Classification) Vote else Average) {
+
+  require(trees.forall(_.algo == algo))
+}
+
+/**
+ * :: Experimental ::
+ * Represents a gradient boosted trees model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees decision trees in the ensemble
+ * @param treeWeights weights for the trees in the ensemble
+ */
+@Experimental
+class GradientBoostedTreesModel(
+    override val algo: Algo,
+    override val trees: Array[DecisionTreeModel],
+    override val treeWeights: Array[Double])
+  extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+
+  require(trees.size == treeWeights.size)
+}
+
+/**
+ * Represents a tree ensemble model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees decision trees in the ensemble
+ * @param treeWeights weights for the trees in the ensemble
+ * @param combiningStrategy strategy for combining the predictions of the individual trees
+ */
+private[tree] sealed class TreeEnsembleModel(
+    protected val algo: Algo,
+    protected val trees: Array[DecisionTreeModel],
+    protected val treeWeights: Array[Double],
+    protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+  require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")
+
+  private val sumWeights = math.max(treeWeights.sum, 1e-15)
+
+  /**
+   * Predicts for a single data point using the weighted sum of ensemble predictions.
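+   * That is, the result equals treeWeights(0) * trees(0).predict(features) + ... +
+   * treeWeights(numTrees - 1) * trees(numTrees - 1).predict(features); the implementation
+   * below evaluates this weighted sum with a single BLAS ddot call.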
+ * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + val treePredictions = trees.map(_.predict(features)) + blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + } + + /** + * Classifies a single data point based on (weighted) majority votes. + */ + private def predictByVoting(features: Vector): Double = { + val votes = mutable.Map.empty[Int, Double] + trees.view.zip(treeWeights).foreach { case (tree, weight) => + val prediction = tree.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + weight + } + votes.maxBy(_._2)._1 + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + (algo, combiningStrategy) match { + case (Regression, Sum) => + predictBySumming(features) + case (Regression, Average) => + predictBySumming(features) / sumWeights + case (Classification, Sum) => // binary classification + val prediction = predictBySumming(features) + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (prediction > 0.0) 1.0 else 0.0 + case (Classification, Vote) => + predictByVoting(features) + case _ => + throw new IllegalArgumentException( + "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " + + s"($algo, $combiningStrategy).") + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]]. + */ + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + } + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"TreeEnsembleModel classifier with $numTrees trees\n" + case Regression => + s"TreeEnsembleModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"TreeEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Get total number of nodes, summed over all trees in the forest. 
+ */ + def totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 2c281a1ee7..9925aae441 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -74,7 +74,7 @@ public class JavaDecisionTreeSuite implements Serializable { maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); - DecisionTreeModel model = learner.train(rdd.rdd()); + DecisionTreeModel model = learner.run(rdd.rdd()); int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index effb7b8259..8972c229b7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.util.StatCounter import scala.collection.mutable @@ -48,7 +48,7 @@ object EnsembleTestHelper { } def validateClassifier( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, input: Seq[LabeledPoint], requiredAccuracy: Double) { val predictions = input.map(x => model.predict(x.features)) @@ -60,17 +60,27 @@ object EnsembleTestHelper { s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } + /** + * Validates a tree ensemble model for regression. + */ def validateRegressor( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, input: Seq[LabeledPoint], - requiredMSE: Double) { + required: Double, + metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => + prediction - label + } + val metric = metricName match { + case "mse" => + errors.map(err => err * err).sum / errors.size + case "mae" => + errors.map(math.abs).sum / errors.size + } + + assert(metric <= required, + s"validateRegressor calculated $metricName $metric but required $required.") } def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala new file mode 100644 index 0000000000..f3f8eff2db --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[GradientBoostedTrees]]. + */ +class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { + + test("Regression with continuous features: SquaredError") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val rdd = sc.parallelize(arr, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + EnsembleTestHelper.validateRegressor(gbt, arr, 0.03) + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + + test("Regression with continuous features: Absolute Error") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val rdd = sc.parallelize(arr, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + EnsembleTestHelper.validateRegressor(gbt, arr, 0.85, "mae") + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. 
+        assert(gbt.trees.head.toString == dt.toString)
+    }
+  }
+
+  test("Binary classification with continuous features: Log Loss") {
+    GradientBoostedTreesSuite.testCombinations.foreach {
+      case (numIterations, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+        val rdd = sc.parallelize(arr, 2)
+
+        val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
+          numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty,
+          subsamplingRate = subsamplingRate)
+        val boostingStrategy =
+          new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate)
+
+        val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+        assert(gbt.trees.size === numIterations)
+        EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val ensembleStrategy = treeStrategy.copy
+        ensembleStrategy.algo = Regression
+        ensembleStrategy.impurity = Variance
+        val dt = DecisionTree.train(remappedInput, ensembleStrategy)
+
+        // Make sure trees are the same.
+        assert(gbt.trees.head.toString == dt.toString)
+    }
+  }
+
+}
+
+object GradientBoostedTreesSuite {
+
+  // Combinations of (numIterations, learningRate, subsamplingRate) to test
+  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
deleted file mode 100644
index 84de40103d..0000000000
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
-import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
-
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-
-/**
- * Test suite for [[GradientBoosting]].
- */ -class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext { - - test("Regression with continuous features: SquaredError") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, - learningRate, 1, treeStrategy) - - val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateRegressor(gbt, arr, 0.03) - - // Make sure trees are the same. - assert(gbtTree.toString == dt.toString) - } - } - - test("Regression with continuous features: Absolute Error") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError, - learningRate, numClassesForClassification = 2, treeStrategy) - - val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateRegressor(gbt, arr, 0.03) - - // Make sure trees are the same. - assert(gbtTree.toString == dt.toString) - } - } - - test("Binary classification with continuous features: Log Loss") { - GradientBoostingSuite.testCombinations.foreach { - case (numIterations, learningRate, subsamplingRate) => - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] - - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - subsamplingRate = subsamplingRate) - - val dt = DecisionTree.train(remappedInput, treeStrategy) - - val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss, - learningRate, numClassesForClassification = 2, treeStrategy) - - val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) - assert(gbt.weakHypotheses.size === numIterations) - val gbtTree = gbt.weakHypotheses(0) - - EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) - - // Make sure trees are the same. 
-        assert(gbtTree.toString == dt.toString)
-    }
-  }
-
-}
-
-object GradientBoostingSuite {
-
-  // Combinations for estimators, learning rates and subsamplingRate
-  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
-
-}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 2734e089d6..90a8c2dfda 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -41,8 +41,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
 
     val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
-    assert(rf.weakHypotheses.size === 1)
-    val rfTree = rf.weakHypotheses(0)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
 
     val dt = DecisionTree.train(rdd, strategy)
 
@@ -65,7 +65,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
       " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
     binaryClassificationTestWithContinuousFeatures(strategy)
   }
 
@@ -76,8 +77,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
 
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
-    assert(rf.weakHypotheses.size === 1)
-    val rfTree = rf.weakHypotheses(0)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
 
     val dt = DecisionTree.train(rdd, strategy)
 
@@ -175,7 +176,8 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
   test("Binary classification with continuous features and node Id cache: subsampling features") {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+      useNodeIdCache = true)
     binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
   }
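Usage note: the sketch below shows how the pieces renamed in this patch fit together end to end, closely following the new GradientBoostedTreesSuite above. The local SparkContext setup, the object name GBTExample, and the toy dataset are illustrative assumptions, not part of this patch.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
    import org.apache.spark.mllib.tree.impurity.Variance
    import org.apache.spark.mllib.tree.loss.SquaredError

    object GBTExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "GBTExample") // assumed local setup

        // Toy regression data; any RDD[LabeledPoint] works here.
        val data = sc.parallelize((0 until 100).map { i =>
          LabeledPoint(i.toDouble, Vectors.dense(i.toDouble))
        }, 2)

        // Tree-level settings, including algo, go in Strategy ...
        val treeStrategy = new Strategy(algo = Regression, impurity = Variance,
          maxDepth = 2, categoricalFeaturesInfo = Map.empty[Int, Int])

        // ... while boosting-level settings (loss, numIterations, learningRate)
        // go in BoostingStrategy, which wraps the tree strategy.
        val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 10, 0.1)

        // A single train method replaces trainClassifier/trainRegressor, since the
        // task is now determined by treeStrategy.algo.
        val model = GradientBoostedTrees.train(data, boostingStrategy)

        // weakHypotheses is now trees; regression predictions are weighted sums
        // of the individual tree outputs.
        println(s"Trained ${model.trees.size} trees.")
        println(s"Prediction at x = 42: ${model.predict(Vectors.dense(42.0))}")

        sc.stop()
      }
    }

Keeping algo and numClasses inside treeStrategy is what makes the single train entry point possible; the boosting wrapper no longer duplicates them.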