aboutsummaryrefslogtreecommitdiff
path: root/examples/src
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2014-11-20 00:48:59 -0800
committer Xiangrui Meng <meng@databricks.com> 2014-11-20 00:48:59 -0800
commit 15cacc81240eed8834b4730c5c6dc3238f003465 (patch)
tree 8cf0c0d947a8c0de655559c471095e04cf232a06 /examples/src
parent e216ffaead983274428052caa992b20760b2c5e0 (diff)
download spark-15cacc81240eed8834b4730c5c6dc3238f003465.tar.gz
spark-15cacc81240eed8834b4730c5c6dc3238f003465.tar.bz2
spark-15cacc81240eed8834b4730c5c6dc3238f003465.zip
[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc
There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. But for the 1.2 release, we should make the APIs consistent.

1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel and renamed members accordingly.
2. GradientBoosting -> GradientBoostedTrees
3. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy
4. Slightly refactored TreeEnsembleModel (Vote takes weights into consideration.)
5. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`
6. Rename class `train` method to `run` because it hides the static methods with the same name in Java. Deprecated `DecisionTree.train` class method.
7. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting.
8. Fix a bug in GradientBoostedTreesSuite with AbsoluteError
9. doc updates

manishamde jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:

7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs
Diffstat (limited to 'examples/src')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java (renamed from examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java) 18
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala 18
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala (renamed from examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala) 18
3 files changed, 22 insertions, 32 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
index 1af2067b2b..4a5ac404ea 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -27,18 +27,18 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
-import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;
/**
* Classification and regression using gradient-boosted decision trees.
*/
-public final class JavaGradientBoostedTrees {
+public final class JavaGradientBoostedTreesRunner {
private static void usage() {
- System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" +
" <Classification/Regression>");
System.exit(-1);
}
@@ -55,7 +55,7 @@ public final class JavaGradientBoostedTrees {
if (args.length > 2) {
usage();
}
- SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
@@ -64,7 +64,7 @@ public final class JavaGradientBoostedTrees {
// Note: All features are treated as continuous.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
boostingStrategy.setNumIterations(10);
- boostingStrategy.weakLearnerParams().setMaxDepth(5);
+ boostingStrategy.treeStrategy().setMaxDepth(5);
if (algo.equals("Classification")) {
// Compute the number of classes from the data.
@@ -73,10 +73,10 @@ public final class JavaGradientBoostedTrees {
return p.label();
}
}).countByValue().size();
- boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+ boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
// Train a GradientBoosting model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public final class JavaGradientBoostedTrees {
System.out.println("Learned classification tree model:\n" + model);
} else if (algo.equals("Regression")) {
// Train a GradientBoosting model for classification.
- final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
// Evaluate model on training instances and compute training error
JavaPairRDD<Double, Double> predictionAndLabel =
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 63f02cf7b9..98f9d1689c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -352,21 +352,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = tree.predict(y.features) - y.label
- err * err
- }.mean()
- }
-
- /**
- * Calculates the mean squared error for regression.
- */
private[mllib] def meanSquaredError(
- tree: WeightedEnsembleModel,
+ model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
data.map { y =>
- val err = tree.predict(y.features) - y.label
+ val err = model.predict(y.features) - y.label
err * err
}.mean()
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 9b6db01448..1def8b45a2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -21,21 +21,21 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
- * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*
* Note: This script treats all features as real-valued (not categorical).
* To include categorical features, modify categoricalFeaturesInfo.
*/
-object GradientBoostedTrees {
+object GradientBoostedTreesRunner {
case class Params(
input: String = null,
@@ -93,24 +93,24 @@ object GradientBoostedTrees {
def run(params: Params) {
- val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
val sc = new SparkContext(conf)
- println(s"GradientBoostedTrees with parameters:\n$params")
+ println(s"GradientBoostedTreesRunner with parameters:\n$params")
// Load training and test data and cache it.
val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
- boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.treeStrategy.numClassesForClassification = numClasses
boostingStrategy.numIterations = params.numIterations
- boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+ boostingStrategy.treeStrategy.maxDepth = params.maxDepth
val randomSeed = Utils.random.nextInt()
if (params.algo == "Classification") {
val startTime = System.nanoTime()
- val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
@@ -127,7 +127,7 @@ object GradientBoostedTrees {
println(s"Test accuracy = $testAccuracy")
} else if (params.algo == "Regression") {
val startTime = System.nanoTime()
- val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
val elapsedTime = (System.nanoTime() - startTime) / 1e9
println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {