aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2014-11-20 00:48:59 -0800
committer Xiangrui Meng <meng@databricks.com> 2014-11-20 00:48:59 -0800
commit 15cacc81240eed8834b4730c5c6dc3238f003465 (patch)
tree 8cf0c0d947a8c0de655559c471095e04cf232a06 /examples/src/main/java
parent e216ffaead983274428052caa992b20760b2c5e0 (diff)
download spark-15cacc81240eed8834b4730c5c6dc3238f003465.tar.gz
spark-15cacc81240eed8834b4730c5c6dc3238f003465.tar.bz2
spark-15cacc81240eed8834b4730c5c6dc3238f003465.zip
[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc
There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent.

1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly.
2. GradientBoosting -> GradientBoostedTrees
3. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy
4. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.)
5. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`
6. Rename class `train` method to `run` because it hides the static methods with the same name in Java. Deprecated `DecisionTree.train` class method.
7. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting.
8. Fix a bug in GradientBoostedTreesSuite with AbsoluteError
9. doc updates

manishamde jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:

7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java (renamed from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java) 18
1 files changed, 9 insertions, 9 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
index 1af2067b2b..4a5ac404ea 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;
/**
* Classification and regression using gradient-boosted decision trees.
*/
-public final class JavaGradientBoostedTrees {
+public final class JavaGradientBoostedTreesRunner {
private static void usage() {
- System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
" <Classification/Regression>");
System.exit(-1);
}
@@ -55,7 +55,7 @@ public final class JavaGradientBoostedTrees {
if (args.length > 2) {
usage();
}
- SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
@@ -64,7 +64,7 @@ public final class JavaGradientBoostedTrees {
// Note: All features are treated as continuous.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
boostingStrategy.setNumIterations(10);
- boostingStrategy.weakLearnerParams().setMaxDepth(5);
+ boostingStrategy.treeStrategy().setMaxDepth(5);
if (algo.equals("Classification")) {
// Compute the number of classes from the data.
@@ -73,10 +73,10 @@ public final class JavaGradientBoostedTrees {
return p.label();
}
}).countByValue().size();
- boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+ boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
// Train a GradientBoosting model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public final class JavaGradientBoostedTrees {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for regression.
- final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =