diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-20 00:48:59 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-20 00:48:59 -0800 |
commit | 15cacc81240eed8834b4730c5c6dc3238f003465 (patch) | |
tree | 8cf0c0d947a8c0de655559c471095e04cf232a06 /examples/src/main/java | |
parent | e216ffaead983274428052caa992b20760b2c5e0 (diff) | |
download | spark-15cacc81240eed8834b4730c5c6dc3238f003465.tar.gz spark-15cacc81240eed8834b4730c5c6dc3238f003465.tar.bz2 spark-15cacc81240eed8834b4730c5c6dc3238f003465.zip |
[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc
There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent.
1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly.
1. GradientBoosting -> GradientBoostedTrees
1. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy
1. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.)
1. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`
1. Rename the instance method `train` (defined on the class) to `run`, because it hides the static methods with the same name in Java. Deprecate the `DecisionTree.train` instance method.
1. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting.
1. Fix a bug in GradientBoostedTreesSuite with AbsoluteError
1. doc updates
manishamde jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:
7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java (renamed from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java) | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java index 1af2067b2b..4a5ac404ea 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java @@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoosting; +import org.apache.spark.mllib.tree.GradientBoostedTrees; import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; import org.apache.spark.mllib.util.MLUtils; /** * Classification and regression using gradient-boosted decision trees. */ -public final class JavaGradientBoostedTrees { +public final class JavaGradientBoostedTreesRunner { private static void usage() { - System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" + + System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" + " <Classification/Regression>"); System.exit(-1); } @@ -55,7 +55,7 @@ public final class JavaGradientBoostedTrees { if (args.length > 2) { usage(); } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); @@ -64,7 +64,7 @@ public final class JavaGradientBoostedTrees { // Note: All features are treated as continuous. 
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); boostingStrategy.setNumIterations(10); - boostingStrategy.weakLearnerParams().setMaxDepth(5); + boostingStrategy.treeStrategy().setMaxDepth(5); if (algo.equals("Classification")) { // Compute the number of classes from the data. @@ -73,10 +73,10 @@ public final class JavaGradientBoostedTrees { return p.label(); } }).countByValue().size(); - boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression + boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses); // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD<Double, Double> predictionAndLabel = @@ -95,7 +95,7 @@ public final class JavaGradientBoostedTrees { System.out.println("Learned classification tree model:\n" + model); } else if (algo.equals("Regression")) { // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); // Evaluate model on training instances and compute training error JavaPairRDD<Double, Double> predictionAndLabel = |